GmailAddOn / app.py
Cyantist8208's picture
調一下順序
9148198
raw
history blame
8.53 kB
"""
app.py – Tiny-RAG (Gradio playground) + REST API (/ingest, /query, /clear)
"""
# ---------- 1. imports & global helpers -------------
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import math, torch, uvicorn, gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
AutoTokenizer, AutoModelForCausalLM,
AutoTokenizer, AutoModel, AutoConfig
)
import torch.nn.functional as F
from collections import defaultdict
# Hugging Face access token taken from the "HF_token" environment variable
# (None when unset). NOTE(review): env-var names are case-sensitive and this
# differs from the more common "HF_TOKEN" spelling — confirm which one the
# deployment actually sets.
HF_TOKEN = os.getenv("HF_token")
# Hub repo IDs: the causal chat LLM (load_chat) and the embedding model
# (load_embedder).
CHAT_MODEL_ID = "QWen/Qwen1.5-7B-Chat"
EMB_MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1"
# --- lazy loaders -------------------------------------------------------------
# Module-level singletons, None until load_chat() / load_embedder() fill them
# on first use, so importing this file does not download or load any model.
tokenizer, chat_model = None, None
emb_tokenizer, emb_model = None, None
def load_chat():
    """Lazily load the chat tokenizer and causal LM into the module globals.

    Safe to call repeatedly: once both objects are loaded this returns
    immediately. Both globals are assigned together, and only after BOTH
    loads succeed. The previous version set ``tokenizer`` first and tested
    only ``tokenizer``. A failed model load therefore left ``chat_model`` as
    None forever, and later calls crashed on ``chat_model.device``.
    """
    global tokenizer, chat_model
    if tokenizer is not None and chat_model is not None:
        return
    tok = AutoTokenizer.from_pretrained(CHAT_MODEL_ID, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        CHAT_MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, token=HF_TOKEN
    )
    # Publish atomically (from this function's point of view).
    tokenizer, chat_model = tok, model
def load_embedder():
    """Lazily load the embedding tokenizer and encoder into the module globals.

    Safe to call repeatedly: returns immediately once both are loaded. The
    model is put in eval mode, and both globals are assigned only after every
    step succeeds. Otherwise a failed model load would leave
    ``emb_tokenizer`` set and ``emb_model`` None, and every later ``embed()``
    would crash instead of retrying.
    """
    global emb_tokenizer, emb_model
    if emb_tokenizer is not None and emb_model is not None:
        return
    tok = AutoTokenizer.from_pretrained(EMB_MODEL_ID, token=HF_TOKEN)
    cfg = AutoConfig.from_pretrained(EMB_MODEL_ID, token=HF_TOKEN)
    model = AutoModel.from_pretrained(
        EMB_MODEL_ID, device_map="auto", torch_dtype=torch.float16, config=cfg, token=HF_TOKEN
    )
    model.eval()  # inference only: disables dropout etc.
    emb_tokenizer, emb_model = tok, model
@torch.no_grad()
def embed(text:str)->torch.Tensor:
    """Encode *text* into a single unit-length (L2-normalised) vector.

    The sentence representation is the hidden state of the first token
    (CLS pooling). Input longer than the tokenizer's limit is truncated.
    Gradients are disabled for the whole call.
    """
    load_embedder()
    encoded = emb_tokenizer(text, return_tensors="pt", truncation=True)
    encoded = encoded.to(emb_model.device)
    hidden_states = emb_model(**encoded).last_hidden_state
    first_token = hidden_states[:, 0]  # CLS pooling over the batch of one
    unit_rows = F.normalize(first_token, dim=-1)
    return unit_rows.squeeze(0)
# ---------- 2. Tiny in-memory knowledge-base (shared by Gradio & API) --------
# One dict entry per user_id.
# Each entry holds:
#   • "texts": list[str]   – the raw passages we ingested
#   • "vecs" : Tensor[N,d] – their embeddings stacked row-wise, in the same
#                            order as "texts" (None until the first passage
#                            for that user is added)
# Because this is a defaultdict, merely READING kb[user_id] creates an empty
# entry; use kb.get(user_id) / `user_id in kb` when you only want to check.
# --------------------------------------------------------------------------
kb = defaultdict(lambda: {"texts": [], "vecs": None})
def add_docs(user_id: str, docs: list[str]) -> int:
    """Embed the non-blank passages in *docs* and append them to *user_id*'s KB.

    Whitespace-only entries are discarded first. If nothing remains, the KB is
    left untouched and 0 is returned. Otherwise the passages and their
    embeddings are appended, and the count of stored passages is returned.
    """
    kept = list(filter(lambda passage: passage.strip(), docs))
    if not kept:
        return 0
    load_embedder()  # loads the embedding model on first use
    fresh = torch.stack([embed(passage) for passage in kept])
    entry = kb[user_id]  # defaultdict creates the entry if missing
    entry["texts"].extend(kept)
    if entry["vecs"] is None:
        entry["vecs"] = fresh
    else:
        entry["vecs"] = torch.cat([entry["vecs"], fresh])
    return len(kept)
# ----- Qwen-chat prompt helper ---------------------------------------------
def build_qwen_prompt(context: str, user_question: str) -> str:
    """Render a system + user exchange through the Qwen chat template.

    A non-empty *context* is placed before the question under a "Context:"
    heading. The result is the template rendered as text (not token IDs), with
    the assistant turn already opened so generation continues from there.
    """
    load_chat()  # the chat tokenizer carries the template
    prefix = "" if context == "" else f"Context:\n{context}\n\n"
    messages = [
        {"role": "system",
         "content": "You are an email assistant. Use ONLY the context provided."},
        {"role": "user",
         "content": f"{prefix}{user_question}"},
    ]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
# ---------- 4. Gradio playground ------------------------------------------
def store_doc(doc_text: str, user_id="demo"):
    """Gradio callback: save the textbox passage into *user_id*'s knowledge base.

    Returns a short status string for the UI: either a notice that nothing was
    stored, or the passage count now held for that user.
    """
    stored = add_docs(user_id, [doc_text])
    if not stored:
        return "Nothing stored (empty input)."
    total = len(kb[user_id]["texts"])
    return f"Stored — KB now has {total} passage(s)."
def _retrieve_context(question, store, mode):
    """Return the passage text to place in the prompt for retrieval *mode*.

    *mode* is already normalised: "some", "all", or anything else (no context).
    *store* is a KB entry ({"texts", "vecs"}). It is only read for the two
    retrieval modes, so it may be None otherwise.
    """
    if mode == "some":
        q_vec = embed(question)
        # embed() L2-normalises, and stored rows also come from embed(), so
        # this dot product is cosine similarity.
        sims = torch.matmul(store["vecs"], q_vec)  # [N]
        k = min(4, sims.numel())
        idxs = torch.topk(sims, k=k).indices.tolist()
        return "\n".join(store["texts"][i] for i in idxs)
    if mode == "all":
        return "\n".join(store["texts"])
    return ""


def _generate_reply(prompt):
    """Run the chat model on *prompt* and return only the newly generated text."""
    load_chat()
    inputs = tokenizer(prompt, return_tensors="pt").to(chat_model.device)
    output = chat_model.generate(**inputs, max_new_tokens=512)
    # generate() returns the prompt tokens followed by the continuation.
    # The old code decoded everything with skip_special_tokens=True and then
    # split on "<|im_start|>assistant". Qwen chat tokenizers register
    # "<|im_start|>" as a special token, so it was already stripped, the split
    # never matched, and the system/user prompt leaked into the reply.
    # Slicing off the prompt length is correct regardless of the template.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()


def answer(question: str, user_id="demo", history="None"):
    """Gradio callback: optionally retrieve KB passages, then answer with Qwen.

    Args:
        question: the user's question; blank input gets a prompt to ask one.
        user_id: whose knowledge base to read.
        history: retrieval mode, matched case-insensitively after stripping:
            "Some" -> the 4 most similar stored passages,
            "All"  -> every stored passage,
            anything else (default "None") -> no context.

    Returns:
        The assistant reply, or a human-readable status/error string. As a UI
        callback, it reports failures instead of raising.
    """
    try:
        if not question.strip():
            return "Please ask a question."
        mode = str(history).strip().lower()
        # kb is a defaultdict: kb[user_id] would create an empty entry that
        # makes the user look "known" to kb.get / `in kb` checks elsewhere.
        store = kb.get(user_id)
        # Only the retrieval modes need passages. The old truthiness test on
        # `history` also blocked mode "None", since "None" is a non-empty string.
        if mode in ("some", "all") and not (store and store["texts"]):
            return "No reference passage yet. Add one first."
        # The "All" branch previously used `store` without ever assigning it.
        context = _retrieve_context(question, store, mode)
        prompt = build_qwen_prompt(context, question)
        return _generate_reply(prompt)
    except Exception as e:
        # UI boundary: show the failure in the reply box rather than raising.
        return f"Error in app.py: {e}"
    finally:
        torch.cuda.empty_cache()
# ---- UI layout (feel free to tweak cosmetics) -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("### Tiny-RAG playground &nbsp;–&nbsp; 1) paste a passage → store&nbsp;&nbsp; 2) ask a question")
    # Created before both buttons so storing AND answering use the same user
    # ID. Previously only the answer button read this box: passages were always
    # stored under "demo", while questions searched whatever ID was typed.
    user_id_box = gr.Textbox(value="demo", label="User ID")
    # ---- passage ingestion ----
    with gr.Row():
        passage_box = gr.Textbox(lines=6, label="Reference passage")
        store_btn = gr.Button("Store passage")
    status_box = gr.Markdown()
    store_btn.click(fn=store_doc,
                    inputs=[passage_box, user_id_box],
                    outputs=status_box)
    # ---- Q & A ----
    question_box = gr.Textbox(lines=2, label="Ask a question")
    # Retrieval mode passed to answer(): "None" = no context, "Some" = top-4
    # similar passages, "All" = every stored passage. A Radio restricts input to
    # exactly these values; a free-text box silently disabled retrieval on any typo.
    history_cb = gr.Radio(choices=["None", "Some", "All"], value="None",
                          label="Use chat history")
    answer_btn = gr.Button("Answer")
    answer_box = gr.Textbox(lines=6, label="Assistant reply")
    answer_btn.click(
        fn=answer,
        inputs=[question_box, user_id_box, history_cb],
        outputs=answer_box,
    )
# ---------- 3. FastAPI layer --------------------------------------------------
class IngestReq(BaseModel):
    """Body of POST /ingest: passages to add to *user_id*'s knowledge base."""
    user_id: str
    docs: list[str]


class QueryReq(BaseModel):
    """Body of POST /query: a question answered against *user_id*'s KB."""
    user_id: str
    question: str


api = FastAPI()


@api.post("/ingest")
def ingest(req: IngestReq):
    """Embed and store ``req.docs`` for ``req.user_id``.

    Delegates to add_docs(), so the API and the Gradio UI share one ingestion
    path. Blank passages are skipped, and an empty ``docs`` list stores
    nothing; previously torch.stack([]) raised and produced a 500.
    Returns ``{"added": n}`` with the number of passages actually stored.
    """
    return {"added": add_docs(req.user_id, req.docs)}


@api.post("/query")
def rag(req: QueryReq):
    """Answer ``req.question`` using the 4 stored passages most similar to it.

    Raises:
        HTTPException(404): the user has no stored passages (never ingested,
            or emptied via /clear).
    """
    store = kb.get(req.user_id)
    # An entry can exist with no passages: kb is a defaultdict, so any
    # kb[user_id] read creates one, and /clear empties entries in place.
    # Without this check the matmul below crashed with a 500.
    if not store or not store["texts"]:
        raise HTTPException(404, "No knowledge ingested for this user.")
    q_vec = embed(req.question)
    sims = torch.matmul(store["vecs"], q_vec)  # cosine sims: embed() normalises
    topk = torch.topk(sims, k=min(4, sims.size(0))).indices
    context = "\n".join(store["texts"][i] for i in topk.tolist())
    prompt = build_qwen_prompt(context, req.question)
    load_chat()
    inputs = tokenizer(prompt, return_tensors="pt").to(chat_model.device)
    out = chat_model.generate(**inputs, max_new_tokens=512)
    # Decode only the continuation. With skip_special_tokens=True the
    # "<|im_start|>" tag is stripped (it is a special token in Qwen chat
    # tokenizers), so the old split on it never matched and returned the whole
    # prompt plus the answer.
    prompt_len = inputs["input_ids"].shape[-1]
    ans = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
    return {"answer": ans}


@api.post("/clear")
def clear_kb(user_id: str = "demo"):
    """Remove every passage stored for *user_id* (the entry itself is kept)."""
    if user_id not in kb:
        return {"success": False, "message": "User ID not found."}
    kb[user_id]["texts"].clear()
    # Reset to None, the "no embeddings yet" marker that add_docs() checks.
    # The old hard-coded torch.empty((0, 4096)) CPU tensor only concatenates
    # with new embeddings that happen to be 4096-wide and on the CPU, so the
    # next ingest after a clear could fail.
    kb[user_id]["vecs"] = None
    return {"success": True, "message": f"Cleared KB for user '{user_id}'."}


# Mount the Gradio UI at "/" only AFTER the API routes are registered.
# Starlette tries routes in registration order, and a Mount at "/" matches
# every path, so mounting first made /ingest, /query and /clear unreachable.
api = gr.mount_gradio_app(api, demo, path="/")
# ---------- 5. run both (FastAPI + Gradio) -----------------------------------
if __name__ == "__main__":
    # Start the standalone Gradio server; prevent_thread_lock=True makes
    # launch() return instead of blocking, so execution continues below.
    queued_demo = demo.queue()
    queued_demo.launch(prevent_thread_lock=True, share=False)
    # uvicorn.run blocks the main thread for the life of the process, serving
    # the FastAPI app (API routes plus the Gradio UI mounted on it).
    uvicorn.run(api, port=8000, host="0.0.0.0")