Spaces:
Sleeping
Sleeping
"""
app.py – Tiny-RAG (Gradio playground) + REST API (/ingest, /query)
"""
# ---------- 1. imports & global helpers -------------
import os

# Must be set before torch initialises CUDA so the allocator can grow
# segments instead of fragmenting (helps avoid spurious OOMs).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import math, torch, uvicorn, gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    AutoModel, AutoConfig,  # fix: AutoTokenizer was imported twice
)
import torch.nn.functional as F
from collections import defaultdict

# NOTE(review): the env var is spelled lower-case "HF_token" — confirm it
# matches the secret name configured in the deployment (often "HF_TOKEN").
HF_TOKEN = os.getenv("HF_token")
# NOTE(review): the org is usually spelled "Qwen" — confirm this repo id resolves.
CHAT_MODEL_ID = "QWen/Qwen1.5-7B-Chat"
EMB_MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1"

# --- lazy loaders: models are loaded on first use, not at import time ---------
tokenizer, chat_model = None, None        # chat LLM + its tokenizer
emb_tokenizer, emb_model = None, None     # embedding model + its tokenizer
def load_chat():
    """Lazily load the chat tokenizer + model into module globals (idempotent).

    fix: load into locals first and assign the globals only after BOTH loads
    succeed — previously ``tokenizer`` was set before the model download, so a
    failed model load left the guard satisfied and the failure unretryable.
    """
    global tokenizer, chat_model
    if chat_model is not None:
        return
    tok = AutoTokenizer.from_pretrained(CHAT_MODEL_ID, token=HF_TOKEN)
    mdl = AutoModelForCausalLM.from_pretrained(
        CHAT_MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, token=HF_TOKEN
    )
    tokenizer, chat_model = tok, mdl
def load_embedder():
    """Lazily load the embedding tokenizer + model into module globals (idempotent).

    fix: assign the globals only after every load succeeds so a partial
    failure (e.g. a dropped download) can be retried on the next call.
    """
    global emb_tokenizer, emb_model
    if emb_model is not None:
        return
    tok = AutoTokenizer.from_pretrained(EMB_MODEL_ID, token=HF_TOKEN)
    cfg = AutoConfig.from_pretrained(EMB_MODEL_ID, token=HF_TOKEN)
    mdl = AutoModel.from_pretrained(
        EMB_MODEL_ID, device_map="auto", torch_dtype=torch.float16, config=cfg, token=HF_TOKEN
    )
    mdl.eval()  # inference only — disables dropout etc.
    emb_tokenizer, emb_model = tok, mdl
def embed(text: str) -> torch.Tensor:
    """Return the L2-normalised embedding vector of *text* (CLS-token pooling).

    Shape is [d] (batch dimension squeezed off).
    """
    load_embedder()
    inputs = emb_tokenizer(text, return_tensors="pt", truncation=True).to(emb_model.device)
    # fix: run under no_grad — the forward pass otherwise records an autograd
    # graph and holds activation memory for embeddings that are never trained.
    with torch.no_grad():
        vec = emb_model(**inputs).last_hidden_state[:, 0]  # [1, d] CLS pooling
    return F.normalize(vec, dim=-1).squeeze(0)
# ---------- 2. Tiny in-memory knowledge-base ------------------------------
# Shared by both the Gradio playground and the REST API.
# One dict entry per user_id. Each entry holds:
#   • "texts": list[str]   – the raw passages we ingested
#   • "vecs" : Tensor[N,d] – their embeddings stacked row-wise
#     ("vecs" stays None until the first passage is embedded)
# --------------------------------------------------------------------------
kb = defaultdict(lambda: {"texts": [], "vecs": None})
def add_docs(user_id: str, docs: list[str]) -> int:
    """Embed the non-blank entries of *docs* and append them to *user_id*'s KB.

    Returns the number of passages actually stored.
    """
    kept = [passage for passage in docs if passage.strip()]  # drop blanks
    if not kept:
        return 0  # nothing usable — leave the KB untouched
    load_embedder()  # lazy-load the embedding model once
    vectors = torch.stack([embed(passage) for passage in kept])
    entry = kb[user_id]  # defaultdict materialises the entry on first access
    entry["texts"].extend(kept)
    if entry["vecs"] is None:
        entry["vecs"] = vectors
    else:
        entry["vecs"] = torch.cat([entry["vecs"], vectors])
    return len(kept)
# ----- Qwen-chat prompt helper ---------------------------------------------
def build_qwen_prompt(context: str, user_question: str) -> str:
    """Render *context* + *user_question* through Qwen-Chat's prompt template."""
    load_chat()  # the chat tokenizer supplies apply_chat_template
    prefix = f"Context:\n{context}\n\n" if context != "" else ""
    messages = [
        {"role": "system",
         "content": "You are an email assistant. Use ONLY the context provided."},
        {"role": "user",
         "content": f"{prefix}{user_question}"},
    ]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
# ---------- 4. Gradio playground ------------------------------------------
def store_doc(doc_text: str, user_id="demo"):
    """UI callback: take the textbox content and shove it into the KB."""
    if add_docs(user_id, [doc_text]) == 0:
        return "Nothing stored (empty input)."
    total = len(kb[user_id]["texts"])
    return f"Stored — KB now has {total} passage(s)."
def answer(question: str, user_id="demo", history="None"):
    """UI callback: retrieve, build prompt with Qwen tags, generate answer.

    *history* selects the retrieval mode: "None" (no context), "Some"
    (top-4 most similar passages) or "All" (every stored passage).
    """
    try:
        if not question.strip():
            return "Please ask a question."
        store = kb[user_id]
        # fix: only demand stored passages when a retrieval mode is selected —
        # the old `if history` was always true (the string "None" is truthy).
        if history in ("Some", "All") and not store["texts"]:
            return "No reference passage yet. Add one first."
        context = ""
        # 1. Retrieve context
        if history == "Some":
            q_vec = embed(question)
            sims = torch.matmul(store["vecs"], q_vec)  # [N] scores (vectors are unit-norm)
            k = min(4, sims.numel())
            idxs = torch.topk(sims, k=k).indices.tolist()
            context = "\n".join(store["texts"][i] for i in idxs)
        elif history == "All":
            # fix: `store` was unbound on this branch (NameError at runtime)
            context = "\n".join(store["texts"])
        # 2. Build a Qwen-chat prompt (helper defined earlier)
        prompt = build_qwen_prompt(context, question)
        # 3. Generate and keep only the newly generated tokens.
        load_chat()
        inputs = tokenizer(prompt, return_tensors="pt").to(chat_model.device)
        output = chat_model.generate(**inputs, max_new_tokens=512)
        # fix: skip_special_tokens strips the <|im_start|> markers, so splitting
        # the full decode on "<|im_start|>assistant" returned prompt + reply.
        gen_tokens = output[0][inputs["input_ids"].shape[-1]:]
        return tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
    except Exception as e:  # surface the error in the UI instead of a traceback
        return f"Error in app.py: {e}"
    finally:
        torch.cuda.empty_cache()  # release cached CUDA blocks between requests
# ---- UI layout (feel free to tweak cosmetics) -----------------------------
# NOTE(review): the original indentation was lost in extraction — confirm
# which widgets belong inside the Row (here: passage box + store button).
with gr.Blocks() as demo:
    gr.Markdown("### Tiny-RAG playground – 1) paste a passage → store 2) ask a question")
    # ---- passage ingestion ----
    with gr.Row():
        passage_box = gr.Textbox(lines=6, label="Reference passage")
        store_btn = gr.Button("Store passage")
    status_box = gr.Markdown()  # shows the "Stored — …" confirmation
    store_btn.click(fn=store_doc,
                    inputs=passage_box,
                    outputs=status_box)
    # ---- Q & A ----
    question_box = gr.Textbox(lines=2, label="Ask a question")
    user_id_box = gr.Textbox(value="demo", label="User ID")
    # Free-text mode selector; answer() expects "None", "Some" or "All".
    history_cb = gr.Textbox(value="None", label="Use chat history")
    answer_btn = gr.Button("Answer")
    answer_box = gr.Textbox(lines=6, label="Assistant reply")
    answer_btn.click(
        fn=answer,
        inputs=[question_box, user_id_box, history_cb],
        outputs=answer_box
    )
# ---------- 3. FastAPI layer --------------------------------------------------
class IngestReq(BaseModel):
    """Request body for the /ingest endpoint."""
    user_id: str      # KB bucket the passages are appended to
    docs: list[str]   # raw passages to embed and store
class QueryReq(BaseModel):
    """Request body for the /query endpoint."""
    user_id: str    # KB bucket to retrieve from
    question: str   # natural-language question to answer
api = FastAPI()
# Serve the Gradio playground at "/" on the same FastAPI app/port.
api = gr.mount_gradio_app(api, demo, path="/")
@api.post("/ingest")  # fix: route decorator was missing — the endpoint promised
# in the module docstring was never registered with FastAPI
def ingest(req: IngestReq):
    """REST endpoint: embed ``req.docs`` and append them to ``req.user_id``'s KB.

    Returns ``{"added": n}`` where *n* counts the passages actually stored.
    """
    # fix: delegate to add_docs so blank passages are skipped and an
    # empty/all-blank list no longer crashes torch.stack on zero tensors.
    added = add_docs(req.user_id, req.docs)
    return {"added": added}
@api.post("/query")  # fix: route decorator was missing — the endpoint promised
# in the module docstring was never registered with FastAPI
def rag(req: QueryReq):
    """REST endpoint: answer ``req.question`` from ``req.user_id``'s passages.

    Raises HTTP 404 when the user has no ingested knowledge.
    """
    store = kb.get(req.user_id)
    # fix: also reject an existing-but-empty store (e.g. after clear_kb)
    if not store or not store["texts"]:
        raise HTTPException(404, "No knowledge ingested for this user.")
    q_vec = embed(req.question)
    sims = torch.matmul(store["vecs"], q_vec)  # [N] similarity scores
    topk = torch.topk(sims, k=min(4, sims.size(0))).indices
    context = "\n".join(store["texts"][i] for i in topk.tolist())
    prompt = build_qwen_prompt(context, req.question)
    load_chat()
    inputs = tokenizer(prompt, return_tensors="pt").to(chat_model.device)
    out = chat_model.generate(**inputs, max_new_tokens=512)
    # fix: skip_special_tokens strips the <|im_start|> markers, so splitting the
    # full decode on them returned prompt + reply; decode only the new tokens.
    gen_tokens = out[0][inputs["input_ids"].shape[-1]:]
    ans = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
    return {"answer": ans}
def clear_kb(user_id: str = "demo"):
    """Remove every stored passage and embedding for *user_id*.

    Returns a ``{"success": bool, "message": str}`` status dict.
    """
    if user_id in kb:
        kb[user_id]["texts"].clear()
        # fix: reset "vecs" to None — the pristine state the defaultdict factory
        # and add_docs expect. The old hard-coded torch.empty((0, 4096)) baked in
        # an embedding dim that may not match the model (mxbai-embed-large-v1
        # appears to be 1024-d — verify), so a later torch.cat would fail.
        kb[user_id]["vecs"] = None
        return {"success": True, "message": f"Cleared KB for user '{user_id}'."}
    else:
        return {"success": False, "message": "User ID not found."}
# ---------- 5. run both (FastAPI + Gradio) -----------------------------------
if __name__ == "__main__":
    # launch Gradio on a background thread (prevent_thread_lock=True returns
    # immediately instead of blocking on the Gradio server)
    demo.queue().launch(share=False, prevent_thread_lock=True)
    # then start FastAPI (uvicorn blocks the main thread).
    # NOTE(review): demo is ALSO mounted on `api` at "/" via mount_gradio_app,
    # so the UI is served twice (standalone launch + mounted) — confirm intended.
    uvicorn.run(api, host="0.0.0.0", port=8000)