fsojni committed on
Commit
9d99ca9
·
verified ·
1 Parent(s): 50e96a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -226,16 +226,17 @@ def ingest(req:IngestReq):
226
  return {"added": len(req.docs)}
227
 
228
  @api.post("/query")
229
- def rag(req:QueryReq):
230
  store = kb.get(req.user_id)
231
  if not store:
232
  raise HTTPException(404, "No knowledge ingested for this user.")
233
- q_vec = embed(req.question)
234
- sims = torch.matmul(store["vecs"], q_vec)
235
- topk = torch.topk(sims, k=min(4, sims.size(0))).indices
 
236
  context = "\n".join(store["texts"][i] for i in topk.tolist())
237
 
238
- SYSTEM_PROMPT = "You are a helpful assistant."
239
  prompt = build_qwen_prompt(SYSTEM_PROMPT, [context], req.question)
240
 
241
  load_chat()
@@ -244,7 +245,6 @@ def rag(req:QueryReq):
244
  return_tensors="pt",
245
  add_special_tokens=False,
246
  )
247
-
248
  if tokens["input_ids"].size(1) > MAX_PROMPT_TOKENS:
249
  tokens = {k: v[:, -MAX_PROMPT_TOKENS:] for k, v in tokens.items()}
250
 
@@ -255,12 +255,11 @@ def rag(req:QueryReq):
255
  max_new_tokens=512,
256
  max_length=MAX_PROMPT_TOKENS + 512,
257
  )
258
-
259
-
260
  full = tokenizer.decode(out[0], skip_special_tokens=True)
261
  ans = full.split("<|im_start|>assistant")[-1].strip()
262
  return {"answer": ans}
263
 
 
264
  # ---------- 5. run both (FastAPI + Gradio) -----------------------------------
265
  if __name__ == "__main__":
266
  # launch Gradio on a background thread
 
226
  return {"added": len(req.docs)}
227
 
228
  @api.post("/query")
229
+ def rag(req: QueryReq):
230
  store = kb.get(req.user_id)
231
  if not store:
232
  raise HTTPException(404, "No knowledge ingested for this user.")
233
+
234
+ q_vec = embed(req.question)
235
+ sims = torch.matmul(store["vecs"], q_vec)
236
+ topk = torch.topk(sims, k=min(4, sims.size(0))).indices
237
  context = "\n".join(store["texts"][i] for i in topk.tolist())
238
 
239
+ SYSTEM_PROMPT = "You are a helpful assistant."
240
  prompt = build_qwen_prompt(SYSTEM_PROMPT, [context], req.question)
241
 
242
  load_chat()
 
245
  return_tensors="pt",
246
  add_special_tokens=False,
247
  )
 
248
  if tokens["input_ids"].size(1) > MAX_PROMPT_TOKENS:
249
  tokens = {k: v[:, -MAX_PROMPT_TOKENS:] for k, v in tokens.items()}
250
 
 
255
  max_new_tokens=512,
256
  max_length=MAX_PROMPT_TOKENS + 512,
257
  )
 
 
258
  full = tokenizer.decode(out[0], skip_special_tokens=True)
259
  ans = full.split("<|im_start|>assistant")[-1].strip()
260
  return {"answer": ans}
261
 
262
+
263
  # ---------- 5. run both (FastAPI + Gradio) -----------------------------------
264
  if __name__ == "__main__":
265
  # launch Gradio on a background thread