Cyantist8208 committed on
Commit
6b79d5e
·
1 Parent(s): 09c5a80

llama good

Browse files
Files changed (1) hide show
  1. app.py +13 -62
app.py CHANGED
@@ -77,7 +77,7 @@ def add_docs(user_id: str, docs: list[str]) -> int:
77
  )
78
  return len(docs)
79
  # ----- Qwen-chat prompt helper ---------------------------------------------
80
- def build_qwen_prompt(system: str, context: list[str], user_question: str) -> str:
81
  """Return a Qwen-style prompt with multiple context items."""
82
  load_chat() # 確保 tokenizer 載入
83
 
@@ -93,12 +93,19 @@ def build_qwen_prompt(system: str, context: list[str], user_question: str) -> st
93
  # 加入最終問題
94
  conversation.append({"role": "user", "content": user_question})
95
 
96
- return tokenizer.apply_chat_template(
97
- conversation, tokenize=False, add_generation_prompt=True
98
- )
 
 
 
 
 
 
 
 
99
 
100
  # ---------- 4. Gradio playground (same UI as before) --------------------------
101
- # ---------- 4. Gradio playground ------------------------------------------
102
  def store_doc(doc_text: str, user_id="demo"):
103
  """UI callback: take the textbox content and shove it into the KB."""
104
  try:
@@ -137,7 +144,7 @@ def answer(system: str, context: str, question: str, user_id="demo", history="No
137
  context_list += store["texts"]
138
 
139
  # 2. Build a Qwen-chat prompt (helper defined earlier)
140
- prompt = build_qwen_prompt(system, context_list, question)
141
 
142
  # 3. Tokenise & cap
143
  load_chat()
@@ -211,62 +218,6 @@ with gr.Blocks() as demo:
211
  outputs=answer_box
212
  )
213
 
214
- # ---------- 3. FastAPI layer --------------------------------------------------
215
- class IngestReq(BaseModel):
216
- user_id:str
217
- docs:list[str]
218
-
219
- class QueryReq(BaseModel):
220
- user_id:str
221
- question:str
222
-
223
- api = FastAPI()
224
- api = gr.mount_gradio_app(api, demo, path="/")
225
-
226
- @api.post("/ingest")
227
- def ingest(req:IngestReq):
228
- load_embedder()
229
- vecs = torch.stack([embed(t) for t in req.docs])
230
- store = kb.setdefault(req.user_id, {"texts":[], "vecs":None})
231
- store["texts"].extend(req.docs)
232
- store["vecs"] = vecs if store["vecs"] is None else torch.cat([store["vecs"], vecs])
233
- return {"added": len(req.docs)}
234
-
235
- @api.post("/query")
236
- def rag(req: QueryReq):
237
- store = kb.get(req.user_id)
238
- if not store:
239
- raise HTTPException(404, "No knowledge ingested for this user.")
240
-
241
- q_vec = embed(req.question)
242
- sims = torch.matmul(store["vecs"], q_vec)
243
- topk = torch.topk(sims, k=min(4, sims.size(0))).indices
244
- context = "\n".join(store["texts"][i] for i in topk.tolist())
245
-
246
- SYSTEM_PROMPT = "You are a helpful assistant."
247
- prompt = build_qwen_prompt(SYSTEM_PROMPT, [context], req.question)
248
-
249
- load_chat()
250
- tokens = tokenizer(
251
- prompt,
252
- return_tensors="pt",
253
- add_special_tokens=False,
254
- )
255
- if tokens["input_ids"].size(1) > MAX_PROMPT_TOKENS:
256
- tokens = {k: v[:, -MAX_PROMPT_TOKENS:] for k, v in tokens.items()}
257
-
258
- tokens = {k: v.to(chat_model.device) for k, v in tokens.items()}
259
-
260
- out = chat_model.generate(
261
- **tokens,
262
- max_new_tokens=512,
263
- max_length=MAX_PROMPT_TOKENS + 512,
264
- )
265
- full = tokenizer.decode(out[0], skip_special_tokens=True)
266
- ans = full.split("<|im_start|>assistant")[-1].strip()
267
- return {"answer": ans}
268
-
269
-
270
  # ---------- 5. run both (FastAPI + Gradio) -----------------------------------
271
  if __name__ == "__main__":
272
  # launch Gradio on a background thread
 
77
  )
78
  return len(docs)
79
  # ----- Qwen-chat prompt helper ---------------------------------------------
80
+ def build_llm_prompt(system: str, context: list[str], user_question: str) -> str:
81
  """Return a Qwen-style prompt with multiple context items."""
82
  load_chat() # 確保 tokenizer 載入
83
 
 
93
  # 加入最終問題
94
  conversation.append({"role": "user", "content": user_question})
95
 
96
+ prompt = ""
97
+ for turn in conversation:
98
+ role = turn["role"]
99
+ content = turn["content"].strip()
100
+ if role == "system":
101
+ prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
102
+ elif role == "user":
103
+ prompt += f"[INST] {content.strip()} [/INST]\n"
104
+ elif role == "assistant":
105
+ prompt += f"{content.strip()}\n"
106
+ return prompt
107
 
108
  # ---------- 4. Gradio playground (same UI as before) --------------------------
 
109
  def store_doc(doc_text: str, user_id="demo"):
110
  """UI callback: take the textbox content and shove it into the KB."""
111
  try:
 
144
  context_list += store["texts"]
145
 
146
  # 2. Build a Qwen-chat prompt (helper defined earlier)
147
+ prompt = build_llm_prompt(system, context_list, question)
148
 
149
  # 3. Tokenise & cap
150
  load_chat()
 
218
  outputs=answer_box
219
  )
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  # ---------- 5. run both (FastAPI + Gradio) -----------------------------------
222
  if __name__ == "__main__":
223
  # launch Gradio on a background thread