""" app.py – Tiny-RAG (Gradio playground) + REST API (/ingest, /query) """ # ---------- 1. imports & global helpers ------------- import os os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import math, torch, uvicorn, gradio as gr from fastapi import FastAPI, HTTPException from pydantic import BaseModel from transformers import ( AutoTokenizer, AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig ) import torch.nn.functional as F from collections import defaultdict HF_TOKEN = os.getenv("HF_token") CHAT_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct" EMB_MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1" MAX_PROMPT_TOKENS = 8192 # ---------- new defaults & helper ------------------ DEFAULT_TEMP = 0.7 DEFAULT_TOP_P = 0.9 DEFAULT_TOP_K_TOK = 40 # token-level sampling DEFAULT_CHUNK_SIZE = 512 # characters DEFAULT_CHUNK_OVERLAP = 128 def chunk_text(text: str, size: int, overlap: int): """Yield sliding-window chunks of *text* with character overlap.""" for start in range(0, len(text), size - overlap): yield text[start : start + size] # --- lazy loaders (unchanged) ------------------------------------------------- tokenizer, chat_model = None, None emb_tokenizer, emb_model = None, None def load_chat(): global tokenizer, chat_model if tokenizer is None: tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_ID, token=HF_TOKEN) chat_model = AutoModelForCausalLM.from_pretrained( CHAT_MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, token=HF_TOKEN ) def load_embedder(): global emb_tokenizer, emb_model if emb_tokenizer is None: emb_tokenizer = AutoTokenizer.from_pretrained(EMB_MODEL_ID, token=HF_TOKEN) cfg = AutoConfig.from_pretrained(EMB_MODEL_ID, token=HF_TOKEN) emb_model = AutoModel.from_pretrained( EMB_MODEL_ID, device_map="auto", torch_dtype=torch.float16, config=cfg, token=HF_TOKEN ) emb_model.eval() @torch.no_grad() def embed(text:str)->torch.Tensor: load_embedder() with torch.no_grad(): inputs = emb_tokenizer(text, return_tensors="pt", truncation=True).to(emb_model.device) vec = emb_model(**inputs).last_hidden_state[:, 0] return F.normalize(vec, dim=-1).cpu() # ---------- 2. tiny in-memory KB shared by Gradio & API ---------------------- # ---------- 2. Tiny in-memory knowledge-base ------------------------------- # One dict entry per user_id. # Each entry holds: # • "texts": list[str] – the raw passages we ingested # • "vecs" : Tensor[N,d] – their embeddings stacked row-wise # -------------------------------------------------------------------------- kb = defaultdict(lambda: {"texts": [], "vecs": None}) def add_docs(user_id: str,docs: list[str],chunk_size: int = DEFAULT_CHUNK_SIZE,chunk_overlap: int = DEFAULT_CHUNK_OVERLAP) -> int: # ---------- NEW ---------- chunks = [] for d in docs: chunks.extend(chunk_text(d, chunk_size, chunk_overlap)) docs = [c for c in chunks if c.strip()] load_embedder() # lazy-load once new_vecs = torch.stack([embed(t) for t in docs]).cpu() store = kb[user_id] # auto-creates via defaultdict store["texts"].extend(docs) store["vecs"] = ( new_vecs if store["vecs"] is None else torch.cat([store["vecs"], new_vecs]) ) return len(docs) # ----- Qwen-chat prompt helper --------------------------------------------- def build_llm_prompt(system: str, context: list[str], user_question: str) -> str: """ 建立適用於 LLaMA/Qwen 等模型的 prompt,支援多段 context, 並強化 system prompt 限制模型僅輸出回應內容。 """ load_chat() # 確保 tokenizer 載入 # 強化指令:防止解釋與步驟 system_prompt = ( f"{system.strip()}\n" "Do not include any explanations, steps, or analysis. " "Only output the final reply content." 
) conversation = [ {"role": "system", "content": system_prompt} ] # 多段 context 當作 user 發言 for ctx in context: if ctx.strip(): # 忽略空內容 conversation.append({"role": "user", "content": ctx.strip()}) # 最後加入使用者問題 conversation.append({"role": "user", "content": user_question.strip()}) # 套用 LLaMA-style prompt 格式 return tokenizer.apply_chat_template( conversation, tokenize=False, add_generation_prompt=False ) # ---------- 4. Gradio playground (same UI as before) -------------------------- def store_doc(doc_text: str,user_id="demo",chunk_size=DEFAULT_CHUNK_SIZE,chunk_overlap=DEFAULT_CHUNK_OVERLAP): try: n = add_docs(user_id, [doc_text], chunk_size, chunk_overlap) if n == 0: return "Nothing stored (empty input)." return f"Stored — KB now has {len(kb[user_id]['texts'])} passage(s)." except Exception as e: return f"Error during storing: {e}" import traceback def answer(system: str, context: str, question: str, user_id="demo", history="None", temperature=DEFAULT_TEMP, top_p=DEFAULT_TOP_P, top_k_tok=DEFAULT_TOP_K_TOK): """UI callback: retrieve, build prompt with Qwen tags, generate answer.""" try: if not question.strip(): return "Please ask a question." if history != "None" and not kb[user_id]["texts"]: return "No reference passage yet. Add one first." context_list = [context] # 1. Retrieve top-k similar passages if history == "Some": q_vec = embed(question).view(-1).cpu() store = kb[user_id] vecs = store["vecs"] if vecs is None or vecs.size(0) == 0: return "Knowledge base is empty or corrupted." sims = torch.matmul(vecs, q_vec) # [N] if sims.dim() > 1: sims = sims.squeeze(1) k = min(4, sims.size(0)) idxs = torch.topk(sims, k=k, dim=0).indices.tolist() context_list += [store["texts"][i] for i in idxs] elif history == "All": store = kb[user_id] context_list += store["texts"] # 2. Build a Qwen-chat prompt (helper defined earlier) prompt = build_llm_prompt(system, context_list, question) # 3. Tokenise & cap load_chat() tokens = tokenizer( prompt, return_tensors="pt", add_special_tokens=False, # we built the chat template ourselves ) if tokens["input_ids"].size(1) > MAX_PROMPT_TOKENS: tokens = {k: v[:, -MAX_PROMPT_TOKENS:] for k, v in tokens.items()} tokens = {k: v.to(chat_model.device) for k, v in tokens.items()} # --- generate ------------------------------------------------------ output = chat_model.generate( **tokens, max_new_tokens=512, max_length=MAX_PROMPT_TOKENS + 512, do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k_tok ) full = tokenizer.decode(output[0], skip_special_tokens=False) start = "<|start_header_id|>assistant<|end_header_id|>\n\n" startwithoutend = "<|start_header_id|>assistant" end = "<|eot_id|>" if start in full: reply = full.split(start)[-1].split(end)[0].strip() elif startwithoutend in full: reply = full.split(startwithoutend)[-1].split(end)[0].strip() else: reply = full return reply except Exception as e: tb = traceback.format_exc() return f"Error in app.py: {tb}, k={k}, sims.numel()={sims.numel()}, sims.shape={sims.shape if 'q_vec' in locals() else 'N/A'}" finally: torch.cuda.empty_cache() def clear_kb(user_id="demo"): if user_id in kb: kb[user_id]["texts"].clear() kb[user_id]["vecs"] = None return f"Cleared KB for user '{user_id}'." else: return f"User ID '{user_id}' not found." 
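
# For reference, the Llama-3.1 chat template rendered by build_llm_prompt()
# with add_generation_prompt=True looks roughly like this (illustrative
# sketch, not the exact token stream):
#
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
#   ...system prompt...<|eot_id|><|start_header_id|>user<|end_header_id|>
#
#   ...question...<|eot_id|><|start_header_id|>assistant<|end_header_id|>
#
# The reply extraction in answer() above keys off these header/eot markers.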
# ---- UI layout (feel free to tweak cosmetics) --------------------------------
with gr.Blocks() as demo:
    gr.Markdown("### Tiny-RAG playground …")

    # ---- passage ingestion ----
    with gr.Row():
        passage_box = gr.Textbox(lines=6, label="Reference passage")
        user_id_box = gr.Textbox(value="demo", label="User ID")
        chunk_box = gr.Slider(128, 2048, value=DEFAULT_CHUNK_SIZE, step=64,
                              label="Chunk size (chars)")
        overlap_box = gr.Slider(0, 1024, value=DEFAULT_CHUNK_OVERLAP, step=32,
                                label="Chunk overlap")
    store_btn = gr.Button("Store passage")
    clear_btn = gr.Button("Clear KB")
    status_box = gr.Markdown()  # declare *before* wiring handlers

    # ---- wire handlers (each button exactly once) ----
    store_btn.click(
        fn=store_doc,
        inputs=[passage_box, user_id_box, chunk_box, overlap_box],
        outputs=status_box,
    )
    clear_btn.click(fn=clear_kb, inputs=user_id_box, outputs=status_box)

    # ---------- Q & A ----------
    question_box = gr.Textbox(lines=2, label="Ask a question")
    # answer() expects exactly "None" / "Some" / "All", so constrain the input
    history_cb = gr.Radio(["None", "Some", "All"], value="None",
                          label="Use chat history")
    system_box = gr.Textbox(lines=2, label="System prompt")
    context_box = gr.Textbox(lines=6, label="Context passages")

    # sampling sliders
    temp_box = gr.Slider(0.0, 1.5, value=DEFAULT_TEMP, step=0.05, label="Temperature")
    topp_box = gr.Slider(0.0, 1.0, value=DEFAULT_TOP_P, step=0.01, label="Top-p")
    topk_box = gr.Slider(1, 100, value=DEFAULT_TOP_K_TOK, step=1, label="Top-k (tokens)")

    answer_btn = gr.Button("Answer")
    answer_box = gr.Textbox(lines=6, label="Assistant reply")
    answer_btn.click(
        fn=answer,
        inputs=[system_box, context_box, question_box, user_id_box,
                history_cb, temp_box, topp_box, topk_box],
        outputs=answer_box,
    )


# ---------- 4. FastAPI layer ---------------------------------------------------
class IngestReq(BaseModel):
    user_id: str
    docs: list[str]


class QueryReq(BaseModel):
    user_id: str
    question: str


api = FastAPI()


# The module docstring promises /ingest and /query, so wire minimal endpoints
# to the same helpers the UI uses. The response shapes and the default system
# prompt below are assumptions, not fixed by the original code.
@api.post("/ingest")
def ingest(req: IngestReq):
    n = add_docs(req.user_id, req.docs)
    return {"stored_chunks": n, "total_passages": len(kb[req.user_id]["texts"])}


@api.post("/query")
def query(req: QueryReq):
    if not kb[req.user_id]["texts"]:
        raise HTTPException(status_code=404, detail="No documents ingested for this user.")
    reply = answer(system="You are a helpful assistant.", context="",
                   question=req.question, user_id=req.user_id, history="Some")
    return {"answer": reply}


# Mount the Gradio UI on the same FastAPI app at "/"; enable the queue first.
api = gr.mount_gradio_app(api, demo.queue(), path="/")

# ---------- 5. run (one server hosts the REST API and the mounted UI) ----------
if __name__ == "__main__":
    # A separate demo.launch() would spin up a redundant second server;
    # uvicorn alone serves both the API routes and the mounted playground.
    uvicorn.run(api, host="0.0.0.0", port=8000)
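
# Example usage against the endpoints above (hypothetical payloads, assuming
# the server is running locally on port 8000):
#
#   curl -X POST http://localhost:8000/ingest \
#        -H "Content-Type: application/json" \
#        -d '{"user_id": "demo", "docs": ["Tiny-RAG stores passages in RAM."]}'
#
#   curl -X POST http://localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -d '{"user_id": "demo", "question": "Where are passages stored?"}'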