fsojni committed on
Commit
71db3d7
·
verified ·
1 Parent(s): 2756958
Files changed (1) hide show
  1. app.py +25 -23
app.py CHANGED
@@ -73,6 +73,19 @@ def add_docs(user_id: str, docs: list[str]) -> int:
73
  else torch.cat([store["vecs"], new_vecs])
74
  )
75
  return len(docs)
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  # ---------- 3. FastAPI layer --------------------------------------------------
78
  class IngestReq(BaseModel):
@@ -104,13 +117,7 @@ def rag(req:QueryReq):
104
  topk = torch.topk(sims, k=min(4, sims.size(0))).indices
105
  context = "\n".join(store["texts"][i] for i in topk.tolist())
106
 
107
- prompt = f"""You are an email assistant.
108
- Use the context to answer.
109
- Context:
110
- {context}
111
-
112
- User question: {req.question}
113
- Assistant:"""
114
 
115
  load_chat()
116
  inputs = tokenizer(prompt, return_tensors="pt").to(chat_model.device)
@@ -128,35 +135,30 @@ def store_doc(doc_text: str, user_id="demo"):
128
  return f"📚 Stored ✅ — KB now has {len(kb[user_id]['texts'])} passage(s)."
129
 
130
  def answer(question: str, user_id="demo"):
131
- """UI callback: retrieve, build prompt, generate answer."""
132
  if not question.strip():
133
  return "⚠️ Please ask a question."
134
  if not kb[user_id]["texts"]:
135
  return "⚠️ No reference passage yet. Add one first."
136
 
137
- # 1️⃣ Retrieve top-k similar chunks (k ≤ #chunks)
138
  q_vec = embed(question)
139
  store = kb[user_id]
140
- sims = torch.matmul(store["vecs"], q_vec) # [N]
141
  k = min(4, sims.numel())
142
  idxs = torch.topk(sims, k=k).indices.tolist()
143
  context = "\n".join(store["texts"][i] for i in idxs)
144
 
145
- # 2️⃣ Build prompt
146
- prompt = f"""You are an email assistant.
147
- Use ONLY the context below to answer.
148
- Context:
149
- {context}
150
 
151
- Question: {question}
152
- Answer:"""
153
-
154
- # 3️⃣ Generate
155
  load_chat()
156
- inputs = tokenizer(prompt, return_tensors="pt").to(chat_model.device)
157
- output = chat_model.generate(**inputs, max_new_tokens=512)
158
- reply = tokenizer.decode(output[0], skip_special_tokens=True)
159
- return reply.split("Answer:", 1)[-1].strip()
 
160
 
161
  # ---- UI layout (feel free to tweak cosmetics) -----------------------------
162
  with gr.Blocks() as demo:
 
73
  else torch.cat([store["vecs"], new_vecs])
74
  )
75
  return len(docs)
76
+ # ----- Qwen-chat prompt helper ---------------------------------------------
77
+ def build_qwen_prompt(context: str, user_question: str) -> str:
78
+ """Return a string that follows Qwen-Chat’s template."""
79
+ conversation = [
80
+ {"role": "system",
81
+ "content": "You are an email assistant. Use ONLY the context provided."},
82
+ {"role": "user",
83
+ "content": f"Context:\n{context}\n\n{user_question}"}
84
+ ]
85
+ # add_generation_prompt=True appends the assistant tag
86
+ return tokenizer.apply_chat_template(
87
+ conversation, tokenize=False, add_generation_prompt=True
88
+ )
89
 
90
  # ---------- 3. FastAPI layer --------------------------------------------------
91
  class IngestReq(BaseModel):
 
117
  topk = torch.topk(sims, k=min(4, sims.size(0))).indices
118
  context = "\n".join(store["texts"][i] for i in topk.tolist())
119
 
120
+ prompt = build_qwen_prompt(context, req.question)
 
 
 
 
 
 
121
 
122
  load_chat()
123
  inputs = tokenizer(prompt, return_tensors="pt").to(chat_model.device)
 
135
  return f"📚 Stored ✅ — KB now has {len(kb[user_id]['texts'])} passage(s)."
136
 
137
  def answer(question: str, user_id="demo"):
138
+ """UI callback: retrieve, build prompt with Qwen tags, generate answer."""
139
  if not question.strip():
140
  return "⚠️ Please ask a question."
141
  if not kb[user_id]["texts"]:
142
  return "⚠️ No reference passage yet. Add one first."
143
 
144
+ # 1️⃣ Retrieve top-k similar passages
145
  q_vec = embed(question)
146
  store = kb[user_id]
147
+ sims = torch.matmul(store["vecs"], q_vec) # [N]
148
  k = min(4, sims.numel())
149
  idxs = torch.topk(sims, k=k).indices.tolist()
150
  context = "\n".join(store["texts"][i] for i in idxs)
151
 
152
+ # 2️⃣ Build a Qwen-chat prompt (helper defined earlier)
153
+ prompt = build_qwen_prompt(context, question)
 
 
 
154
 
155
+ # 3️⃣ Generate and strip everything before the assistant tag
 
 
 
156
  load_chat()
157
+ inputs = tokenizer(prompt, return_tensors="pt").to(chat_model.device)
158
+ output = chat_model.generate(**inputs, max_new_tokens=512)
159
+ full = tokenizer.decode(output[0], skip_special_tokens=True)
160
+ reply = full.split("<|im_start|>assistant")[-1].strip()
161
+ return reply
162
 
163
  # ---- UI layout (feel free to tweak cosmetics) -----------------------------
164
  with gr.Blocks() as demo: