Subha95 committed on
Commit
be9b0cf
·
verified ·
1 Parent(s): 92551e7

Update chatbot_rag.py

Browse files
Files changed (1) hide show
  1. chatbot_rag.py +21 -12
chatbot_rag.py CHANGED
@@ -41,27 +41,36 @@ def build_qa():
41
  # 3. Load LLM (SmolLM2-1.7B-Instruct)
42
  print("🔹 Loading LLM...")
43
 
44
- model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
45
-
46
- # 🔹 Load tokenizer & model
47
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
48
  model = AutoModelForCausalLM.from_pretrained(
49
- model_name,
50
- device_map="auto", # auto selects GPU if available, else CPU
51
- torch_dtype="auto" # picks best dtype automatically
 
52
  )
53
 
54
- # 🔹 Create HF pipeline
55
  pipe = pipeline(
56
  "text-generation",
57
  model=model,
58
  tokenizer=tokenizer,
59
- max_new_tokens=512, # adjust output length
60
- temperature=0.2,
61
- do_sample=False, # creativity vs determinism
62
- top_p=0.9 # nucleus sampling
 
 
 
63
  )
64
 
 
 
 
65
  # 🔹 Wrap in LangChain LLM
66
  llm = HuggingFacePipeline(pipeline=pipe)
67
 
 
41
  # 3. Load LLM (microsoft/phi-2)
42
  print("🔹 Loading LLM...")
43
 
44
+ model_id = "microsoft/phi-2"
45
+
46
+ # Load tokenizer
47
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
48
+
49
+ # Load model
50
  model = AutoModelForCausalLM.from_pretrained(
51
+ model_id,
52
+ device_map="auto", # put on GPU if available, else CPU
53
+ torch_dtype="auto", # auto precision
54
+ trust_remote_code=True # allow custom model code
55
  )
56
 
57
+ # Create pipeline
58
  pipe = pipeline(
59
  "text-generation",
60
  model=model,
61
  tokenizer=tokenizer,
62
+ max_new_tokens=256, # control length of response
63
+ temperature=0.2, # more deterministic
64
+ do_sample=False, # no randomness (deterministic answers)
65
+ top_p=0.9, # nucleus sampling
66
+ repetition_penalty=1.2, # 🚀 reduce loops/repeats
67
+ eos_token_id=tokenizer.eos_token_id,
68
+ return_full_text=False
69
  )
70
 
71
+ # Wrap into LangChain LLM
72
+ llm = HuggingFacePipeline(pipeline=pipe)
73
+
74
  # 🔹 Wrap in LangChain LLM
75
  llm = HuggingFacePipeline(pipeline=pipe)
76