fsojni committed
Commit 4872cd0 · verified · 1 Parent(s): 6947209

Update app.py

Files changed (1):
  1. app.py  +9 -25
app.py CHANGED
@@ -41,12 +41,12 @@ def load_embedder():
     emb_model.eval()
 
 @torch.no_grad()
-def embed(text: str) -> torch.Tensor:
+def embed(text:str)->torch.Tensor:
+    """Return L2-normalised embedding vector."""
     load_embedder()
-    with torch.no_grad():
-        inputs = emb_tokenizer(text, return_tensors="pt", truncation=True).to(emb_model.device)
-        vec = emb_model(**inputs).last_hidden_state[:, 0]
-        return F.normalize(vec, dim=-1).cpu()
+    inputs = emb_tokenizer(text, return_tensors="pt", truncation=True).to(emb_model.device)
+    vec = emb_model(**inputs).last_hidden_state[:, 0]  # CLS pooling
+    return F.normalize(vec, dim=-1).squeeze(0)
 
 # ---------- 2. tiny in-memory KB shared by Gradio & API ----------------------
 # ---------- 2. Tiny in-memory knowledge-base -------------------------------
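The substance of this hunk is twofold: the inner `with torch.no_grad():` was redundant, since the `@torch.no_grad()` decorator already disables autograd for the entire call, and the return value changes from a `[1, d]` CPU tensor to a 1-D `[d]` vector that stays on the embedder's device. A minimal runnable sketch of the new behaviour; the model ID is a stand-in, since the diff never shows which embedder app.py actually loads:

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

MODEL_ID = "BAAI/bge-small-en-v1.5"   # stand-in; the real embedder is defined elsewhere in app.py
emb_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
emb_model = AutoModel.from_pretrained(MODEL_ID).eval()

@torch.no_grad()  # disables autograd for the whole call; no inner `with` needed
def embed(text: str) -> torch.Tensor:
    inputs = emb_tokenizer(text, return_tensors="pt", truncation=True).to(emb_model.device)
    vec = emb_model(**inputs).last_hidden_state[:, 0]  # CLS pooling -> [1, d]
    return F.normalize(vec, dim=-1).squeeze(0)         # unit-norm vector, shape [d]

v = embed("hello world")
print(v.shape, float(v.norm()))  # torch.Size([384]) ~1.0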
@@ -67,7 +67,7 @@ def add_docs(user_id: str, docs: list[str]) -> int:
         return 0
 
     load_embedder()  # lazy-load once
-    new_vecs = torch.stack([embed(t) for t in docs]).cpu()
+    new_vecs = torch.stack([embed(t) for t in docs])
     store = kb[user_id]  # auto-creates via defaultdict
     store["texts"].extend(docs)
     store["vecs"] = (
@@ -119,7 +119,7 @@ def answer(system: str, context: str, question: str, user_id="demo", history="No
     context_list = [context]
     # 1. Retrieve top-k similar passages
     if history == "Some":
-        q_vec = embed(question).cpu()
+        q_vec = embed(question)
         store = kb[user_id]
         sims = torch.matmul(store["vecs"], q_vec)  # [N]
         k = min(4, sims.numel())
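Because `embed` L2-normalises its output, the `torch.matmul(store["vecs"], q_vec)` in the context lines computes exact cosine similarities, and removing `.cpu()` keeps `q_vec` on the same device as the stored matrix. The hunk ends at `k = min(4, sims.numel())`; the `topk` continuation below is an assumption about what follows, chained onto the stub KB from the previous sketch:

add_docs("demo", ["Paris is the capital of France.",
                  "The Eiffel Tower is in Paris.",
                  "Mount Fuji is in Japan."])

store = kb["demo"]
q_vec = embed("Where is the Eiffel Tower?")   # [d], unit norm
sims = torch.matmul(store["vecs"], q_vec)     # [N]; dot product == cosine here
k = min(4, sims.numel())
top = sims.topk(k)                            # assumed continuation of the hunk
retrieved = [store["texts"][i] for i in top.indices.tolist()]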
@@ -130,23 +130,7 @@ def answer(system: str, context: str, question: str, user_id="demo", history="No
         context_list += store["texts"]
 
     # 2. Build a Qwen-chat prompt (helper defined earlier)
-    MAX_PROMPT_TOKENS = 8192  # 8 k is ~4 GB KV-cache
-
-    prompt = build_qwen_prompt(system, context_list, question)
-    tokens = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
-
-    if tokens.input_ids.size(1) > MAX_PROMPT_TOKENS:
-        # keep the last MAX_PROMPT_TOKENS tokens (most recent content)
-        tokens = {k: v[:, -MAX_PROMPT_TOKENS:] for k, v in tokens.items()}
-
-    tokens = {k: v.to(chat_model.device) for k, v in tokens.items()}
-
-    output = chat_model.generate(
-        **tokens,
-        max_new_tokens=512,
-        max_length=MAX_PROMPT_TOKENS + 512,
-    )
-
+    prompt = build_qwen_prompt(system, context_list, question)
 
     # 3. Generate and strip everything before the assistant tag
     load_chat()
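The deleted block tokenised, truncated, and called `chat_model.generate` before `load_chat()` had run, so it appears to have duplicated the generation that step 3 performs (and on a cold start the chat model may not even have been loaded yet); the commit keeps only the prompt build. The deleted call also passed both `max_new_tokens=512` and `max_length=MAX_PROMPT_TOKENS + 512`, which recent transformers versions flag as conflicting, with `max_new_tokens` taking precedence. If the 8k-token cap is still wanted, it could return as a small left-truncation helper; `truncate_left` below is hypothetical, not part of the commit:

import torch

MAX_PROMPT_TOKENS = 8192  # the deleted cap; the old comment put it at ~4 GB of KV-cache

def truncate_left(tokens: dict, limit: int = MAX_PROMPT_TOKENS) -> dict:
    # Keep the most recent `limit` tokens of every field (input_ids,
    # attention_mask, ...) -- hypothetical helper, not in the commit.
    if tokens["input_ids"].size(1) <= limit:
        return tokens
    return {k: v[:, -limit:] for k, v in tokens.items()}

toks = {"input_ids": torch.arange(10000).unsqueeze(0),
        "attention_mask": torch.ones(1, 10000, dtype=torch.long)}
print(truncate_left(toks)["input_ids"].shape)  # torch.Size([1, 8192])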
@@ -247,4 +231,4 @@ if __name__ == "__main__":
     # launch Gradio on a background thread
     demo.queue().launch(share=False, prevent_thread_lock=True)
     # then start FastAPI (uvicorn blocks main thread)
-    uvicorn.run(api, host="0.0.0.0", port=8000)
+    uvicorn.run(api, host="0.0.0.0", port=8000)
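The `-`/`+` pair here is identical in content, so the change is presumably whitespace-only; the line-number shift (247 to 231) is just the 16 net lines deleted in the previous hunk. The launch pattern itself is worth noting: `prevent_thread_lock=True` makes Gradio's `launch()` return immediately so uvicorn can block the main thread. A self-contained sketch, with trivial stand-ins since the real `demo` and `api` are built earlier in app.py:

import gradio as gr
import uvicorn
from fastapi import FastAPI

api = FastAPI()                                                  # stand-in FastAPI app
demo = gr.Interface(lambda s: s, inputs="text", outputs="text")  # stand-in UI

if __name__ == "__main__":
    # launch() returns instead of blocking thanks to prevent_thread_lock=True
    demo.queue().launch(share=False, prevent_thread_lock=True)
    # uvicorn then owns the main thread on its own port
    uvicorn.run(api, host="0.0.0.0", port=8000)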
 