Subha95 committed on
Commit
772864e
·
verified ·
1 Parent(s): fb273e0

Update chatbot_rag.py

Browse files
Files changed (1) hide show
  1. chatbot_rag.py +27 -69
chatbot_rag.py CHANGED
@@ -1,16 +1,16 @@
1
-
2
  from langchain_community.vectorstores import Chroma
3
  from langchain_community.embeddings import HuggingFaceEmbeddings
4
  from langchain_community.llms import HuggingFacePipeline
5
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, pipeline
6
  from langchain.prompts import PromptTemplate
7
  from langchain_core.runnables import RunnablePassthrough
8
  from langchain_core.output_parsers import StrOutputParser
9
- import traceback
10
  import re
11
  import os
12
-
13
  from huggingface_hub import login
 
 
14
  token = os.getenv("HF_TOKEN")
15
  print("πŸ”‘ HF_TOKEN available?", token is not None)
16
  if token:
@@ -38,49 +38,41 @@ def build_qa():
38
  )
39
  print("πŸ“‚ Docs in DB:", vectorstore._collection.count())
40
 
41
- # 3. Load LLM (Phi-3 mini)
42
  print("πŸ”Ή Loading LLM...")
 
43
 
44
- model_id = "microsoft/Phi-3-mini-4k-instruct"
45
-
46
- # Load tokenizer
47
  tokenizer = AutoTokenizer.from_pretrained(model_id)
48
-
49
- # Load model
50
  model = AutoModelForCausalLM.from_pretrained(
51
  model_id,
52
- device_map="auto", # put on GPU if available, else CPU
53
- torch_dtype="auto", # auto precision
54
- trust_remote_code=True # allow custom model code
55
  )
56
  model.config.use_cache = False
57
-
58
- # Create pipeline
59
  pipe = pipeline(
60
  "text-generation",
61
  model=model,
62
  tokenizer=tokenizer,
63
- max_new_tokens=256, # control length of response
64
- temperature=0.2, # more deterministic
65
- do_sample=False, # no randomness (deterministic answers)
66
- top_p=0.9, # nucleus sampling
67
- repetition_penalty=1.2, # πŸš€ reduce loops/repeats
68
  eos_token_id=tokenizer.eos_token_id,
69
  return_full_text=False
70
  )
71
-
72
-
73
- # πŸ”Ή Wrap in LangChain LLM
74
  llm = HuggingFacePipeline(pipeline=pipe)
75
 
76
  # 4. Retriever
77
  retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
78
 
 
79
  prompt = PromptTemplate(
80
  input_variables=["context", "question"],
81
- template="""
82
- Answer the question using the context below.
83
- Respond in ONE short factual sentence only.
84
  If you don't know, say "I don't know."
85
 
86
  Context:
@@ -92,51 +84,21 @@ def build_qa():
92
  Answer:""",
93
  )
94
 
95
-
96
-
97
-
98
- # 6. Helper functions
99
  def format_docs(docs):
100
- """Join document contents into a single string, skipping empty ones."""
101
- texts = []
102
- for doc in docs:
103
- if doc.page_content and isinstance(doc.page_content, str):
104
- texts.append(doc.page_content.strip())
105
  return "\n".join(texts)
106
 
107
-
108
-
109
  def hf_to_str(x):
110
- """Convert Hugging Face pipeline output to clean plain text."""
111
  if isinstance(x, list) and "generated_text" in x[0]:
112
  text = x[0]["generated_text"]
113
  else:
114
  text = str(x)
115
-
116
- # Remove code-like patterns (imports, defs, classes, etc.)
117
- text = re.sub(r"(from\s+\w+\s+import\s+.*|import\s+\w+.*)", "", text)
118
- text = re.sub(r"def\s+\w+\(.*?\):.*", "", text, flags=re.DOTALL)
119
- text = re.sub(r"class\s+\w+.*?:.*", "", text, flags=re.DOTALL)
120
- text = re.sub(r"text\s*\+=.*", "", text)
121
-
122
- # Remove markdown/code fences & quotes
123
- text = text.replace("```", "").replace("'''", "").replace('"""', "").replace("\\n", " ")
124
-
125
- # Normalize whitespace
126
- text = re.sub(r"\s+", " ", text)
127
-
128
- # Deduplicate repeated sentences
129
- sentences = []
130
- for s in re.split(r"(?<=[.!?])\s+", text):
131
- if s and s not in sentences:
132
- sentences.append(s)
133
- text = " ".join(sentences)
134
-
135
- return text.strip()
136
-
137
-
138
-
139
- # 7. RAG chain
140
  rag_chain = (
141
  {
142
  "context": retriever | format_docs,
@@ -148,7 +110,6 @@ def build_qa():
148
  | StrOutputParser()
149
  )
150
 
151
-
152
  print("βœ… QA pipeline ready.")
153
  return rag_chain
154
 
@@ -165,14 +126,11 @@ except Exception as e:
165
 
166
 
167
  def get_answer(query: str) -> str:
168
- """
169
- Run a query against the QA pipeline and return the answer text.
170
- """
171
  if qa_pipeline is None:
172
  return "⚠️ QA pipeline not initialized."
173
-
174
  try:
175
- result = qa_pipeline.invoke(query) # for LCEL chain
176
  return result
177
  except Exception as e:
178
  return f"❌ QA run failed: {e}"
 
 
1
  from langchain_community.vectorstores import Chroma
2
  from langchain_community.embeddings import HuggingFaceEmbeddings
3
  from langchain_community.llms import HuggingFacePipeline
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
5
  from langchain.prompts import PromptTemplate
6
  from langchain_core.runnables import RunnablePassthrough
7
  from langchain_core.output_parsers import StrOutputParser
 
8
  import re
9
  import os
10
+ import traceback
11
  from huggingface_hub import login
12
+
13
+
14
  token = os.getenv("HF_TOKEN")
15
  print("πŸ”‘ HF_TOKEN available?", token is not None)
16
  if token:
 
38
  )
39
  print("πŸ“‚ Docs in DB:", vectorstore._collection.count())
40
 
41
+ # 3. Load LLM (Phi-3.5-mini-instruct)
42
  print("πŸ”Ή Loading LLM...")
43
+ model_id = "microsoft/Phi-3.5-mini-instruct"
44
 
 
 
 
45
  tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
46
  model = AutoModelForCausalLM.from_pretrained(
47
  model_id,
48
+ device_map="auto",
49
+ torch_dtype="auto",
50
+ trust_remote_code=True
51
  )
52
  model.config.use_cache = False
53
+
 
54
  pipe = pipeline(
55
  "text-generation",
56
  model=model,
57
  tokenizer=tokenizer,
58
+ max_new_tokens=80, # shorter answers
59
+ temperature=0.2, # deterministic
60
+ do_sample=False,
61
+ repetition_penalty=1.2,
 
62
  eos_token_id=tokenizer.eos_token_id,
63
  return_full_text=False
64
  )
65
+
 
 
66
  llm = HuggingFacePipeline(pipeline=pipe)
67
 
68
  # 4. Retriever
69
  retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
70
 
71
+ # 5. Prompt
72
  prompt = PromptTemplate(
73
  input_variables=["context", "question"],
74
+ template="""Answer the question using the context below.
75
+ Respond in ONE short factual sentence only.
 
76
  If you don't know, say "I don't know."
77
 
78
  Context:
 
84
  Answer:""",
85
  )
86
 
87
+ # 6. Helper
 
 
 
88
def format_docs(docs):
    """Concatenate the non-empty page contents of *docs*, one per line.

    Documents with a falsy ``page_content`` (empty string or ``None``) are
    skipped; surviving contents are stripped of surrounding whitespace.
    """
    return "\n".join(
        doc.page_content.strip() for doc in docs if doc.page_content
    )
91
 
 
 
92
def hf_to_str(x):
    """Normalize a HF text-generation pipeline output to one clean sentence.

    Accepts the raw pipeline output (a non-empty list of dicts carrying a
    ``generated_text`` key, as produced with ``return_full_text=False``) or
    any other value, which is coerced via ``str``. Whitespace runs are
    collapsed and only the first sentence is returned.
    """
    # Guard against x == [] (the original `x[0]` raised IndexError) and
    # against non-dict first elements (`"generated_text" in <str>` would be
    # a substring test followed by a TypeError on subscripting).
    if isinstance(x, list) and x and isinstance(x[0], dict) and "generated_text" in x[0]:
        text = x[0]["generated_text"]
    else:
        text = str(x)
    # Collapse whitespace so the sentence-boundary split below is reliable.
    text = re.sub(r"\s+", " ", text).strip()
    # Keep only the first sentence — the prompt asks for one-sentence answers.
    return re.split(r"(?<=[.!?])\s+", text)[0]
100
+
101
+ # 7. Chain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  rag_chain = (
103
  {
104
  "context": retriever | format_docs,
 
110
  | StrOutputParser()
111
  )
112
 
 
113
  print("βœ… QA pipeline ready.")
114
  return rag_chain
115
 
 
126
 
127
 
128
def get_answer(query: str) -> str:
    """Answer *query* via the module-level RAG chain.

    Returns the chain's text output, or a human-readable error string when
    the pipeline never initialized or the invocation raises.
    """
    if qa_pipeline is None:
        return "⚠️ QA pipeline not initialized."
    try:
        # LCEL chains are driven through .invoke(); result is already a str
        # thanks to the chain's StrOutputParser.
        return qa_pipeline.invoke(query)
    except Exception as e:
        return f"❌ QA run failed: {e}"