Spaces:
Sleeping
Sleeping
| """ | |
| app.py – Tiny-RAG (Gradio playground) + REST API (/ingest, /query) | |
| """ | |
| # ---------- 1. imports & global helpers ------------- | |
| import os | |
| os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" | |
| import math, torch, uvicorn, gradio as gr | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import ( | |
| AutoTokenizer, AutoModelForCausalLM, | |
| AutoTokenizer, AutoModel, AutoConfig | |
| ) | |
| import torch.nn.functional as F | |
| from collections import defaultdict | |
| HF_TOKEN = os.getenv("HF_token") | |
| CHAT_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct" | |
| EMB_MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1" | |
| MAX_PROMPT_TOKENS = 8192 | |
| # ---------- new defaults & helper ------------------ | |
| DEFAULT_TEMP = 0.7 | |
| DEFAULT_TOP_P = 0.9 | |
| DEFAULT_TOP_K_TOK = 40 # token-level sampling | |
| DEFAULT_CHUNK_SIZE = 512 # characters | |
| DEFAULT_CHUNK_OVERLAP = 128 | |
def chunk_text(text: str, size: int, overlap: int):
    """Yield sliding-window chunks of *text*, consecutive chunks sharing
    *overlap* characters.

    Raises:
        ValueError: if ``size <= 0`` or ``overlap`` is not in ``[0, size)``.
            BUGFIX: the original passed ``size - overlap`` straight to
            ``range``; with overlap == size that raised a cryptic ValueError,
            and with overlap > size (possible from the UI sliders) range with
            a negative step silently yielded NOTHING — the passage was lost.
    """
    if size <= 0:
        raise ValueError(f"chunk size must be positive, got {size}")
    if overlap < 0 or overlap >= size:
        raise ValueError(f"overlap must be in [0, size), got {overlap}")
    step = size - overlap  # guaranteed >= 1 by the checks above
    for start in range(0, len(text), step):
        yield text[start : start + size]
# --- lazy loaders (unchanged) -------------------------------------------------
# Module-level singletons, populated on first use by load_chat() /
# load_embedder() so model weights are only fetched when actually needed.
tokenizer, chat_model = None, None
emb_tokenizer, emb_model = None, None
def load_chat():
    """Lazily initialise the chat tokenizer/model pair (no-op after first call)."""
    global tokenizer, chat_model
    if tokenizer is not None:
        return
    tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_ID, token=HF_TOKEN)
    chat_model = AutoModelForCausalLM.from_pretrained(
        CHAT_MODEL_ID,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN,
    )
def load_embedder():
    """Lazily initialise the embedding tokenizer/model pair (no-op after first call)."""
    global emb_tokenizer, emb_model
    if emb_tokenizer is not None:
        return
    emb_tokenizer = AutoTokenizer.from_pretrained(EMB_MODEL_ID, token=HF_TOKEN)
    model_cfg = AutoConfig.from_pretrained(EMB_MODEL_ID, token=HF_TOKEN)
    emb_model = AutoModel.from_pretrained(
        EMB_MODEL_ID,
        device_map="auto",
        torch_dtype=torch.float16,
        config=model_cfg,
        token=HF_TOKEN,
    )
    emb_model.eval()
def embed(text: str) -> torch.Tensor:
    """Return the L2-normalised CLS embedding of *text* on CPU (shape [1, d])."""
    load_embedder()
    with torch.no_grad():
        batch = emb_tokenizer(text, return_tensors="pt", truncation=True)
        batch = batch.to(emb_model.device)
        # First token of the last hidden state is the CLS vector.
        cls_vec = emb_model(**batch).last_hidden_state[:, 0]
        return F.normalize(cls_vec, dim=-1).cpu()
# ---------- 2. Tiny in-memory knowledge-base shared by Gradio & the API -----
# One dict entry per user_id (auto-created on first access via defaultdict).
# Each entry holds:
#     • "texts": list[str] – the raw passage chunks we ingested
#     • "vecs" : Tensor    – their embeddings stacked row-wise
#       NOTE(review): add_docs() stacks embed() outputs of shape [1, d], so
#       the stored tensor is actually [N, 1, d]; answer() squeezes the extra
#       dim when scoring — confirm before flattening.
# ---------------------------------------------------------------------------
kb = defaultdict(lambda: {"texts": [], "vecs": None})
def add_docs(
    user_id: str,
    docs: list[str],
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
) -> int:
    """Chunk *docs*, embed each chunk, and append them to *user_id*'s KB.

    Returns:
        Number of chunks actually stored (0 for empty/blank input).
    """
    chunks: list[str] = []
    for doc in docs:
        chunks.extend(chunk_text(doc, chunk_size, chunk_overlap))
    chunks = [c for c in chunks if c.strip()]
    # BUGFIX: torch.stack([]) raises RuntimeError, so the original crashed on
    # empty/blank input and store_doc's "Nothing stored" branch was
    # unreachable. Bail out early instead.
    if not chunks:
        return 0
    load_embedder()  # lazy-load once
    new_vecs = torch.stack([embed(c) for c in chunks]).cpu()
    store = kb[user_id]  # auto-creates via defaultdict
    store["texts"].extend(chunks)
    store["vecs"] = (
        new_vecs if store["vecs"] is None
        else torch.cat([store["vecs"], new_vecs])
    )
    return len(chunks)
# ----- Qwen-chat prompt helper ---------------------------------------------
def build_llm_prompt(system: str, context: list[str], user_question: str) -> str:
    """
    Build a chat prompt for LLaMA/Qwen-style models, supporting multiple
    context passages, with a hardened system prompt that restricts the model
    to emitting only the final reply content.
    """
    load_chat()  # make sure the tokenizer is loaded
    # Hardened instruction: suppress explanations / step-by-step output.
    system_prompt = (
        f"{system.strip()}\n"
        "Do not include any explanations, steps, or analysis. "
        "Only output the final reply content."
    )
    conversation = [
        {"role": "system", "content": system_prompt}
    ]
    # Each non-empty context passage becomes its own user turn.
    for ctx in context:
        if ctx.strip():  # skip blank passages
            conversation.append({"role": "user", "content": ctx.strip()})
    # Finally append the actual user question.
    conversation.append({"role": "user", "content": user_question.strip()})
    # Render with the model's chat template (LLaMA-style header tags).
    # NOTE(review): add_generation_prompt=False means the assistant header is
    # NOT appended before generation; answer() compensates when parsing the
    # output — confirm this is intentional (True is the usual choice here).
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False
    )
# ---------- 4. Gradio playground (same UI as before) --------------------------
def store_doc(doc_text: str, user_id="demo", chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP):
    """UI callback: chunk + embed one passage and report the new KB size."""
    try:
        stored = add_docs(user_id, [doc_text], chunk_size, chunk_overlap)
    except Exception as e:
        return f"Error during storing: {e}"
    if stored == 0:
        return "Nothing stored (empty input)."
    return f"Stored — KB now has {len(kb[user_id]['texts'])} passage(s)."
| import traceback | |
def answer(system: str, context: str, question: str,
           user_id="demo", history="None",
           temperature=DEFAULT_TEMP,
           top_p=DEFAULT_TOP_P,
           top_k_tok=DEFAULT_TOP_K_TOK):
    """UI callback: retrieve KB passages, build a chat prompt, generate a reply.

    history selects retrieval mode (exact string match):
        "None" – use only the pasted context,
        "Some" – also retrieve the top-4 most similar KB passages,
        "All"  – also include every stored KB passage.
    """
    try:
        if not question.strip():
            return "Please ask a question."
        if history != "None" and not kb[user_id]["texts"]:
            return "No reference passage yet. Add one first."
        context_list = [context]

        # 1. Retrieve similar passages from the KB (if requested).
        if history == "Some":
            q_vec = embed(question).view(-1).cpu()
            store = kb[user_id]
            vecs = store["vecs"]
            if vecs is None or vecs.size(0) == 0:
                return "Knowledge base is empty or corrupted."
            sims = torch.matmul(vecs, q_vec)
            if sims.dim() > 1:       # stored vecs are [N,1,d] → sims [N,1]
                sims = sims.squeeze(1)
            k = min(4, sims.size(0))
            idxs = torch.topk(sims, k=k, dim=0).indices.tolist()
            context_list += [store["texts"][i] for i in idxs]
        elif history == "All":
            store = kb[user_id]
            context_list += store["texts"]

        # 2. Build the chat prompt (helper defined earlier).
        prompt = build_llm_prompt(system, context_list, question)

        # 3. Tokenise & cap prompt length, keeping the most recent tokens.
        load_chat()
        tokens = tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,  # we built the chat template ourselves
        )
        if tokens["input_ids"].size(1) > MAX_PROMPT_TOKENS:
            tokens = {key: v[:, -MAX_PROMPT_TOKENS:] for key, v in tokens.items()}
        tokens = {key: v.to(chat_model.device) for key, v in tokens.items()}

        # 4. Generate.
        # BUGFIX: the original passed BOTH max_new_tokens and max_length,
        # which conflict in transformers; max_new_tokens alone is the intent.
        output = chat_model.generate(
            **tokens,
            max_new_tokens=512,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k_tok,
        )

        # 5. Extract the assistant turn from the decoded output.
        full = tokenizer.decode(output[0], skip_special_tokens=False)
        start = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        startwithoutend = "<|start_header_id|>assistant"
        end = "<|eot_id|>"
        if start in full:
            reply = full.split(start)[-1].split(end)[0].strip()
        elif startwithoutend in full:
            reply = full.split(startwithoutend)[-1].split(end)[0].strip()
        else:
            reply = full
        return reply
    except Exception:
        # BUGFIX: the original handler interpolated `k` and `sims`, which are
        # unbound unless the history=="Some" branch ran — the handler itself
        # then raised NameError. Report only the traceback.
        return f"Error in app.py: {traceback.format_exc()}"
    finally:
        torch.cuda.empty_cache()
| def clear_kb(user_id="demo"): | |
| if user_id in kb: | |
| kb[user_id]["texts"].clear() | |
| kb[user_id]["vecs"] = None | |
| return f"Cleared KB for user '{user_id}'." | |
| else: | |
| return f"User ID '{user_id}' not found." | |
# ---- UI layout (feel free to tweak cosmetics) -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("### Tiny-RAG playground …")
    # ---- passage ingestion ----
    with gr.Row():
        passage_box = gr.Textbox(lines=6, label="Reference passage")
        user_id_box = gr.Textbox(value="demo", label="User ID")
        chunk_box = gr.Slider(128, 2048, value=DEFAULT_CHUNK_SIZE,
                              step=64, label="Chunk size (chars)")
        # NOTE(review): overlap can exceed chunk size here (max 1024 vs min
        # 128), which chunk_text cannot handle meaningfully — confirm ranges.
        overlap_box = gr.Slider(0, 1024, value=DEFAULT_CHUNK_OVERLAP,
                                step=32, label="Chunk overlap")
    store_btn = gr.Button("Store passage")
    clear_btn = gr.Button("Clear KB")
    status_box = gr.Markdown()  # declare *before* wiring handlers
    # ---- wire handlers (each button exactly once) ----
    store_btn.click(
        fn=store_doc,
        inputs=[passage_box, user_id_box, chunk_box, overlap_box],
        outputs=status_box
    )
    clear_btn.click(
        fn=clear_kb,
        inputs=user_id_box,
        outputs=status_box
    )
    # ---------- Q & A ----------
    question_box = gr.Textbox(lines=2, label="Ask a question")
    # NOTE(review): free-text field, but answer() only recognises the exact
    # strings "None" / "Some" / "All" — a dropdown would be safer; confirm.
    history_cb = gr.Textbox(value="None", label="Use chat history")
    system_box = gr.Textbox(lines=2, label="System prompt")
    context_box = gr.Textbox(lines=6, label="Context passages")
    # Sampling controls, forwarded straight to chat_model.generate().
    temp_box = gr.Slider(0.0, 1.5, value=DEFAULT_TEMP,
                         step=0.05, label="Temperature")
    topp_box = gr.Slider(0.0, 1.0, value=DEFAULT_TOP_P,
                         step=0.01, label="Top-p")
    topk_box = gr.Slider(1, 100, value=DEFAULT_TOP_K_TOK,
                         step=1, label="Top-k (tokens)")
    answer_btn = gr.Button("Answer")
    answer_box = gr.Textbox(lines=6, label="Assistant reply")
    answer_btn.click(
        fn=answer,
        inputs=[system_box, context_box, question_box,
                user_id_box, history_cb,
                temp_box, topp_box, topk_box],
        outputs=answer_box
    )
# ---------- 3. FastAPI layer --------------------------------------------------
class IngestReq(BaseModel):
    # Body of POST /ingest
    user_id: str
    docs: list[str]


class QueryReq(BaseModel):
    # Body of POST /query
    user_id: str
    question: str


api = FastAPI()


# BUGFIX: the request models above were defined but never used — the REST
# endpoints promised by the module docstring (/ingest, /query) did not exist.
# They must be registered BEFORE mounting Gradio at "/", otherwise the
# catch-all mount shadows them.
@api.post("/ingest")
def ingest(req: IngestReq):
    """Chunk + embed *docs* into the caller's KB; returns the stored count."""
    try:
        stored = add_docs(req.user_id, req.docs)
    except Exception as e:  # surface embedding/chunking failures as HTTP 500
        raise HTTPException(status_code=500, detail=str(e))
    return {"stored": stored}


@api.post("/query")
def query(req: QueryReq):
    """Answer *question* from the caller's stored passages (top-k retrieval)."""
    if not req.question.strip():
        raise HTTPException(status_code=400, detail="question is empty")
    reply = answer("You are a helpful assistant.", "", req.question,
                   user_id=req.user_id, history="Some")
    return {"answer": reply}


api = gr.mount_gradio_app(api, demo, path="/")
# ---------- 5. run both (FastAPI + Gradio) -----------------------------------
if __name__ == "__main__":
    # launch Gradio on a background thread (standalone server, default port 7860)
    # NOTE(review): demo is *also* mounted into the FastAPI app at "/", so this
    # starts a second, independent copy of the UI — confirm both are wanted.
    demo.queue().launch(share=False, prevent_thread_lock=True)
    # then start FastAPI (uvicorn blocks main thread)
    uvicorn.run(api, host="0.0.0.0", port=8000)