fsojni committed
Commit 2756958 · verified · 1 Parent(s): 4d4bca5

GPT-generated monstrosity. Trying memory + RAG; it will probably blow up. There's a local backup.

Files changed (1): app.py (+189 −31)
app.py CHANGED
@@ -1,31 +1,189 @@
- import os
- import torch
- import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- # Global variables initialised to None
- tokenizer = None
- model = None
-
- def reply(prompt, model_id="QWen/Qwen1.5-7B-Chat", api_token=None):
-     try:
-         if api_token is None:
-             api_token = os.getenv("HF_token")
-     except Exception as e:
-         return f"Could not get the API token.\nError: {str(e)}"
-
-     global tokenizer, model
-     try:
-         if tokenizer is None or model is None:
-             tokenizer = AutoTokenizer.from_pretrained(model_id, token=api_token)
-             model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16, token=api_token)
-
-         inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
-         outputs = model.generate(**inputs, max_new_tokens=8192)
-         return tokenizer.decode(outputs[0], skip_special_tokens=True)
-     except RuntimeError as e:
-         return f"Runtime error: {str(e)}."
-     except Exception as e:
-         return f"An error occurred: {str(e)}"
-
- gr.Interface(fn=reply, inputs="text", outputs="text").launch()
+ """
+ app.py – Tiny-RAG (Gradio playground) + REST API (/ingest, /query)
+ """
+
+ # ---------- 1. imports & global helpers ---------------------------------------
+ import os, torch, uvicorn, gradio as gr
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from transformers import (
+     AutoTokenizer, AutoModelForCausalLM,
+     AutoModel, AutoConfig
+ )
+ import torch.nn.functional as F
+ from collections import defaultdict
+ HF_TOKEN = os.getenv("HF_token")
+ CHAT_MODEL_ID = "QWen/Qwen1.5-7B-Chat"
+ EMB_MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1"
+
+ # --- lazy loaders (unchanged) --------------------------------------------------
+ tokenizer, chat_model = None, None
+ emb_tokenizer, emb_model = None, None
+
+ def load_chat():
+     global tokenizer, chat_model
+     if tokenizer is None:
+         tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_ID, token=HF_TOKEN)
+         chat_model = AutoModelForCausalLM.from_pretrained(
+             CHAT_MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, token=HF_TOKEN
+         )
+
+ def load_embedder():
+     global emb_tokenizer, emb_model
+     if emb_tokenizer is None:
+         emb_tokenizer = AutoTokenizer.from_pretrained(EMB_MODEL_ID, token=HF_TOKEN)
+         cfg = AutoConfig.from_pretrained(EMB_MODEL_ID, token=HF_TOKEN)
+         emb_model = AutoModel.from_pretrained(
+             EMB_MODEL_ID, device_map="auto", torch_dtype=torch.float16, config=cfg, token=HF_TOKEN
+         )
+         emb_model.eval()
+
+ @torch.no_grad()
+ def embed(text: str) -> torch.Tensor:
+     """Return L2-normalised embedding vector."""
+     load_embedder()
+     inputs = emb_tokenizer(text, return_tensors="pt", truncation=True).to(emb_model.device)
+     vec = emb_model(**inputs).last_hidden_state[:, 0]  # CLS pooling
+     return F.normalize(vec, dim=-1).squeeze(0)
+
+ # ---------- 2. Tiny in-memory knowledge-base, shared by Gradio & API ----------
+ # One dict entry per user_id. Each entry holds:
+ #   • "texts": list[str]   – the raw passages we ingested
+ #   • "vecs" : Tensor[N,d] – their embeddings stacked row-wise
+ # -------------------------------------------------------------------------------
+
+ kb = defaultdict(lambda: {"texts": [], "vecs": None})
+
+ def add_docs(user_id: str, docs: list[str]) -> int:
+     """Embed *docs* and append them to the KB for *user_id*.
+     Returns the number of docs actually stored."""
+     docs = [t for t in docs if t.strip()]  # skip blanks
+     if not docs:
+         return 0
+
+     load_embedder()  # lazy-load once
+     new_vecs = torch.stack([embed(t) for t in docs])
+     store = kb[user_id]  # auto-creates via defaultdict
+     store["texts"].extend(docs)
+     store["vecs"] = (
+         new_vecs if store["vecs"] is None
+         else torch.cat([store["vecs"], new_vecs])
+     )
+     return len(docs)
+
+ # ---------- 3. FastAPI layer ---------------------------------------------------
+ class IngestReq(BaseModel):
+     user_id: str
+     docs: list[str]
+
+ class QueryReq(BaseModel):
+     user_id: str
+     question: str
+
+ api = FastAPI()
+
+ @api.post("/ingest")
+ def ingest(req: IngestReq):
+     # Reuse add_docs so blank docs are skipped (torch.stack on an empty list
+     # would raise) and the KB update logic stays in one place.
+     return {"added": add_docs(req.user_id, req.docs)}
+
+ @api.post("/query")
+ def rag(req: QueryReq):
+     store = kb.get(req.user_id)
+     if not store or store["vecs"] is None:
+         raise HTTPException(404, "No knowledge ingested for this user.")
+     q_vec = embed(req.question)
+     sims = torch.matmul(store["vecs"], q_vec)  # unit vectors, so dot product = cosine similarity
+     topk = torch.topk(sims, k=min(4, sims.size(0))).indices
+     context = "\n".join(store["texts"][i] for i in topk.tolist())
+
+     prompt = f"""You are an email assistant.
+ Use the context to answer.
+ Context:
+ {context}
+
+ User question: {req.question}
+ Assistant:"""
+
+     load_chat()
+     inputs = tokenizer(prompt, return_tensors="pt").to(chat_model.device)
+     out = chat_model.generate(**inputs, max_new_tokens=512)
+     ans = tokenizer.decode(out[0], skip_special_tokens=True).split("Assistant:", 1)[-1].strip()
+     return {"answer": ans}
+
+ # ---------- 4. Gradio playground (same UI as before) --------------------------
+ def store_doc(doc_text: str, user_id="demo"):
+     """UI callback: take the textbox content and shove it into the KB."""
+     n = add_docs(user_id, [doc_text])
+     if n == 0:
+         return "⚠️ Nothing stored (empty input)."
+     return f"📚 Stored ✅ — KB now has {len(kb[user_id]['texts'])} passage(s)."
+
+ def answer(question: str, user_id="demo"):
+     """UI callback: retrieve, build prompt, generate answer."""
+     if not question.strip():
+         return "⚠️ Please ask a question."
+     if not kb[user_id]["texts"]:
+         return "⚠️ No reference passage yet. Add one first."
+
+     # 1️⃣ Retrieve top-k similar chunks (k ≤ #chunks)
+     q_vec = embed(question)
+     store = kb[user_id]
+     sims = torch.matmul(store["vecs"], q_vec)  # [N]
+     k = min(4, sims.numel())
+     idxs = torch.topk(sims, k=k).indices.tolist()
+     context = "\n".join(store["texts"][i] for i in idxs)
+
+     # 2️⃣ Build prompt
+     prompt = f"""You are an email assistant.
+ Use ONLY the context below to answer.
+ Context:
+ {context}
+
+ Question: {question}
+ Answer:"""
+
+     # 3️⃣ Generate
+     load_chat()
+     inputs = tokenizer(prompt, return_tensors="pt").to(chat_model.device)
+     output = chat_model.generate(**inputs, max_new_tokens=512)
+     reply = tokenizer.decode(output[0], skip_special_tokens=True)
+     return reply.split("Answer:", 1)[-1].strip()
+
+ # ---- UI layout (feel free to tweak cosmetics) ---------------------------------
+ with gr.Blocks() as demo:
+     gr.Markdown("### 📥 Tiny-RAG playground  –  1) paste a passage → store   2) ask a question")
+
+     # ---- passage ingestion ----
+     with gr.Row():
+         passage_box = gr.Textbox(lines=6, label="Reference passage")
+         store_btn = gr.Button("➕ Store passage")
+     status_box = gr.Markdown()
+     store_btn.click(fn=store_doc,
+                     inputs=passage_box,
+                     outputs=status_box)
+
+     # ---- Q & A ----
+     question_box = gr.Textbox(lines=2, label="Ask a question")
+     answer_btn = gr.Button("🤖 Answer")
+     answer_box = gr.Textbox(lines=6, label="Assistant reply")
+
+     answer_btn.click(fn=answer,
+                      inputs=question_box,
+                      outputs=answer_box)
+
+ # ---------- 5. run both (FastAPI + Gradio) -------------------------------------
+ if __name__ == "__main__":
+     # launch Gradio on a background thread
+     demo.queue().launch(share=False, prevent_thread_lock=True)
+     # then start FastAPI (uvicorn blocks the main thread)
+     uvicorn.run(api, host="0.0.0.0", port=8000)
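A quick smoke test for the REST half once the app is running. This is a sketch, assuming the server is reachable on localhost:8000 (the uvicorn host/port set in `__main__`) and that the `requests` package is installed; the `user_id`/`docs`/`question` field names come from `IngestReq` and `QueryReq` above, and the sample passages are hypothetical:

    import requests

    BASE = "http://localhost:8000"  # assumption: host/port from uvicorn.run above

    # 1) ingest two (hypothetical) passages for user "demo"
    r = requests.post(f"{BASE}/ingest", json={
        "user_id": "demo",
        "docs": [
            "Alice's flight lands at 18:40 on Friday.",
            "The quarterly report is due next Monday.",
        ],
    })
    print(r.json())  # expect {"added": 2}

    # 2) ask a question grounded in those passages
    r = requests.post(f"{BASE}/query", json={
        "user_id": "demo",
        "question": "When does Alice's flight land?",
    })
    print(r.json()["answer"])

If the two-server setup in section 5 turns out to be flaky, `gr.mount_gradio_app(api, demo, path="/")` is the usual way to have uvicorn serve the API and the Gradio UI from a single process; that is an alternative, not what this commit does.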