Update app.py
app.py
CHANGED

@@ -41,12 +41,12 @@ def load_embedder():
     emb_model.eval()
 
 @torch.no_grad()
-def embed(text:
+def embed(text: str) -> torch.Tensor:
+    """Return L2-normalised embedding vector."""
     load_embedder()
-
-
-
-    return F.normalize(vec, dim=-1).cpu()
+    inputs = emb_tokenizer(text, return_tensors="pt", truncation=True).to(emb_model.device)
+    vec = emb_model(**inputs).last_hidden_state[:, 0]  # CLS pooling
+    return F.normalize(vec, dim=-1).squeeze(0)
 
 # ---------- 2. tiny in-memory KB shared by Gradio & API ----------------------
 # ---------- 2. Tiny in-memory knowledge-base -------------------------------
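
Note: the rewritten embed() does CLS pooling over the encoder output and L2-normalises the vector, so a plain dot product against other embeddings is a cosine similarity, and .squeeze(0) gives a 1-D [D] tensor the retrieval code below can matmul directly. A minimal sketch of that property, with hypothetical tensors and an assumed hidden size of 384 (the real model and dimension come from load_embedder()):

import torch
import torch.nn.functional as F

q = F.normalize(torch.randn(1, 384), dim=-1).squeeze(0)  # [D], like the new embed() output
kb_vecs = F.normalize(torch.randn(10, 384), dim=-1)      # [N, D] stack of stored vectors
sims = torch.matmul(kb_vecs, q)                           # [N] cosine similarities in [-1, 1]
print(sims.shape)                                         # torch.Size([10])
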
@@ -67,7 +67,7 @@ def add_docs(user_id: str, docs: list[str]) -> int:
         return 0
 
     load_embedder()  # lazy-load once
-    new_vecs = torch.stack([embed(t) for t in docs])
+    new_vecs = torch.stack([embed(t) for t in docs])
     store = kb[user_id]  # auto-creates via defaultdict
     store["texts"].extend(docs)
     store["vecs"] = (
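
Note: because embed() now returns a 1-D [D] tensor, torch.stack over the new documents yields a clean [N, D] matrix. The assignment to store["vecs"] is truncated in the diff; a sketch assuming it concatenates the new rows onto any existing matrix:

import torch

existing = torch.randn(3, 384)                                # previously stored vectors, [N, D]
new_vecs = torch.stack([torch.randn(384) for _ in range(2)])  # [2, D], as in add_docs
combined = torch.cat([existing, new_vecs], dim=0)             # [5, D]
print(combined.shape)                                         # torch.Size([5, 384])
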
@@ -119,7 +119,7 @@ def answer(system: str, context: str, question: str, user_id="demo", history="No
     context_list = [context]
     # 1. Retrieve top-k similar passages
     if history == "Some":
-        q_vec = embed(question)
+        q_vec = embed(question)
         store = kb[user_id]
         sims = torch.matmul(store["vecs"], q_vec)  # [N]
         k = min(4, sims.numel())
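
Note: with store["vecs"] of shape [N, D] and the squeezed query vector of shape [D], torch.matmul gives one score per stored passage, and k = min(4, sims.numel()) guards against a store with fewer than four documents. The selection line itself falls outside this hunk; a sketch assuming a standard topk:

import torch

sims = torch.tensor([0.12, 0.87, 0.45, 0.91, 0.03])  # one similarity per stored passage
k = min(4, sims.numel())
scores, idx = sims.topk(k)                            # highest-scoring passages first
print(idx.tolist())                                   # [3, 1, 2, 0]
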
@@ -130,23 +130,7 @@ def answer(system: str, context: str, question: str, user_id="demo", history="No
         context_list += store["texts"]
 
     # 2. Build a Qwen-chat prompt (helper defined earlier)
-
-
-    prompt = build_qwen_prompt(system, context_list, question)
-    tokens = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
-
-    if tokens.input_ids.size(1) > MAX_PROMPT_TOKENS:
-        # keep the last MAX_PROMPT_TOKENS tokens (most recent content)
-        tokens = {k: v[:, -MAX_PROMPT_TOKENS:] for k, v in tokens.items()}
-
-    tokens = {k: v.to(chat_model.device) for k, v in tokens.items()}
-
-    output = chat_model.generate(
-        **tokens,
-        max_new_tokens=512,
-        max_length=MAX_PROMPT_TOKENS + 512,
-    )
-
+    prompt = build_qwen_prompt(system, context_list, question)
 
     # 3. Generate and strip everything before the assistant tag
     load_chat()
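
Note: the removed block tokenised the prompt itself, kept only the last MAX_PROMPT_TOKENS tokens, and called chat_model.generate() at step 2; after this change only the prompt string is built here, and generation happens after load_chat() in step 3. For reference, the tail-truncation pattern the removed code used, as a standalone sketch (the MAX_PROMPT_TOKENS value is assumed, not taken from app.py):

MAX_PROMPT_TOKENS = 2048  # assumed; the real constant is defined elsewhere in app.py

def truncate_to_tail(tokens: dict, limit: int = MAX_PROMPT_TOKENS) -> dict:
    """Keep only the most recent `limit` tokens of each field (input_ids, attention_mask)."""
    if tokens["input_ids"].size(1) > limit:
        return {k: v[:, -limit:] for k, v in tokens.items()}
    return tokens
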
@@ -247,4 +231,4 @@ if __name__ == "__main__":
     # launch Gradio on a background thread
     demo.queue().launch(share=False, prevent_thread_lock=True)
     # then start FastAPI (uvicorn blocks main thread)
-    uvicorn.run(api, host="0.0.0.0", port=8000)
+    uvicorn.run(api, host="0.0.0.0", port=8000)
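
Note: prevent_thread_lock=True makes Gradio's launch() return instead of blocking, so uvicorn.run() can then take the main thread and serve the FastAPI app on port 8000 alongside the Gradio UI. A self-contained sketch of the same pattern with a dummy interface (the real demo and api objects are defined earlier in app.py):

import gradio as gr
import uvicorn
from fastapi import FastAPI

api = FastAPI()
demo = gr.Interface(fn=lambda s: s, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.queue().launch(share=False, prevent_thread_lock=True)  # returns immediately
    uvicorn.run(api, host="0.0.0.0", port=8000)                 # blocks the main thread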