"""
app.py – Tiny-RAG (Gradio playground) + REST API (/ingest, /query)
"""
# ---------- 1. imports & global helpers -------------
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import math, torch, uvicorn, gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    AutoModel, AutoConfig
)
import torch.nn.functional as F
from collections import defaultdict
HF_TOKEN = os.getenv("HF_token")
CHAT_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
EMB_MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1"
MAX_PROMPT_TOKENS = 8192
# ---------- new defaults & helper ------------------
DEFAULT_TEMP = 0.7
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K_TOK = 40 # token-level sampling
DEFAULT_CHUNK_SIZE = 512 # characters
DEFAULT_CHUNK_OVERLAP = 128
def chunk_text(text: str, size: int, overlap: int):
    """Yield sliding-window chunks of *text* with character overlap."""
    step = max(1, size - overlap)  # guard against overlap >= size
    for start in range(0, len(text), step):
        yield text[start : start + size]
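# Example with hypothetical inputs: chunk_text("a" * 1000, 512, 128) yields
# chunks starting at offsets 0, 384 and 768 – 512-character windows with
# 128 characters of overlap between neighbours.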
# --- lazy loaders --------------------------------------------------------------
tokenizer, chat_model = None, None
emb_tokenizer, emb_model = None, None
def load_chat():
global tokenizer, chat_model
if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_ID, token=HF_TOKEN)
chat_model = AutoModelForCausalLM.from_pretrained(
CHAT_MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, token=HF_TOKEN
)
def load_embedder():
global emb_tokenizer, emb_model
if emb_tokenizer is None:
emb_tokenizer = AutoTokenizer.from_pretrained(EMB_MODEL_ID, token=HF_TOKEN)
cfg = AutoConfig.from_pretrained(EMB_MODEL_ID, token=HF_TOKEN)
emb_model = AutoModel.from_pretrained(
EMB_MODEL_ID, device_map="auto", torch_dtype=torch.float16, config=cfg, token=HF_TOKEN
)
emb_model.eval()
@torch.no_grad()
def embed(text: str) -> torch.Tensor:
    """Return the CLS embedding of *text*, L2-normalised, shape [1, d]."""
    load_embedder()
    inputs = emb_tokenizer(text, return_tensors="pt", truncation=True).to(emb_model.device)
    vec = emb_model(**inputs).last_hidden_state[:, 0]  # CLS token
    return F.normalize(vec, dim=-1).cpu()
# ---------- 2. Tiny in-memory knowledge-base shared by Gradio & API ----------
# One dict entry per user_id.
# Each entry holds:
#   • "texts": list[str]   – the raw passages we ingested
#   • "vecs" : Tensor[N,d] – their embeddings stacked row-wise
# ------------------------------------------------------------------------------
kb = defaultdict(lambda: {"texts": [], "vecs": None})
def add_docs(
    user_id: str,
    docs: list[str],
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
) -> int:
    """Chunk *docs*, embed each chunk and append them to the user's KB."""
    chunks = []
    for d in docs:
        chunks.extend(chunk_text(d, chunk_size, chunk_overlap))
    docs = [c for c in chunks if c.strip()]
    if not docs:                      # nothing usable to store
        return 0
    load_embedder()                   # lazy-load once
    new_vecs = torch.cat([embed(t) for t in docs], dim=0).cpu()  # [N, d]
    store = kb[user_id]               # auto-creates via defaultdict
    store["texts"].extend(docs)
    store["vecs"] = (
        new_vecs if store["vecs"] is None
        else torch.cat([store["vecs"], new_vecs])
    )
    return len(docs)
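# Example (hypothetical input): add_docs("demo", ["<a long reference document>"])
# splits the document into overlapping chunks, embeds each one and returns the
# number of chunks appended to kb["demo"]["texts"].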
# ----- chat prompt helper (LLaMA-style template) -----------------------------
def build_llm_prompt(system: str, context: list[str], user_question: str) -> str:
    """
    Build a prompt for LLaMA/Qwen-style chat models. Supports multiple context
    passages and hardens the system prompt so the model only outputs the reply.
    """
    load_chat()  # make sure the tokenizer is loaded
    # Harden the instruction: no explanations or step-by-step reasoning.
    system_prompt = (
        f"{system.strip()}\n"
        "Do not include any explanations, steps, or analysis. "
        "Only output the final reply content."
    )
    conversation = [
        {"role": "system", "content": system_prompt}
    ]
    # Each context passage becomes a separate user turn.
    for ctx in context:
        if ctx.strip():  # skip empty passages
            conversation.append({"role": "user", "content": ctx.strip()})
    # Finally append the actual user question.
    conversation.append({"role": "user", "content": user_question.strip()})
    # Apply the model's chat template; add the assistant header so the model
    # continues with the reply rather than starting a new user turn.
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True
    )
# ---------- 4. Gradio playground ----------------------------------------------
def store_doc(doc_text: str, user_id="demo",
              chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP):
try:
n = add_docs(user_id, [doc_text], chunk_size, chunk_overlap)
if n == 0:
return "Nothing stored (empty input)."
return f"Stored — KB now has {len(kb[user_id]['texts'])} passage(s)."
except Exception as e:
return f"Error during storing: {e}"
import traceback
def answer(system: str, context: str, question: str,
           user_id="demo", history="None",
           temperature=DEFAULT_TEMP,
           top_p=DEFAULT_TOP_P,
           top_k_tok=DEFAULT_TOP_K_TOK):
    """UI callback: retrieve context, build the chat prompt, generate an answer.

    *history* selects how much of the KB is used:
      "None" – only the pasted context box,
      "Some" – the top-4 most similar stored passages,
      "All"  – every stored passage.
    """
try:
if not question.strip():
return "Please ask a question."
if history != "None" and not kb[user_id]["texts"]:
return "No reference passage yet. Add one first."
context_list = [context]
# 1. Retrieve top-k similar passages
if history == "Some":
q_vec = embed(question).view(-1).cpu()
store = kb[user_id]
vecs = store["vecs"]
if vecs is None or vecs.size(0) == 0:
return "Knowledge base is empty or corrupted."
sims = torch.matmul(vecs, q_vec) # [N]
if sims.dim() > 1:
sims = sims.squeeze(1)
k = min(4, sims.size(0))
idxs = torch.topk(sims, k=k, dim=0).indices.tolist()
context_list += [store["texts"][i] for i in idxs]
elif history == "All":
store = kb[user_id]
context_list += store["texts"]
        # 2. Build the chat prompt (helper defined earlier)
prompt = build_llm_prompt(system, context_list, question)
# 3. Tokenise & cap
load_chat()
tokens = tokenizer(
prompt,
return_tensors="pt",
add_special_tokens=False, # we built the chat template ourselves
)
if tokens["input_ids"].size(1) > MAX_PROMPT_TOKENS:
tokens = {k: v[:, -MAX_PROMPT_TOKENS:] for k, v in tokens.items()}
tokens = {k: v.to(chat_model.device) for k, v in tokens.items()}
# --- generate ------------------------------------------------------
        output = chat_model.generate(
            **tokens,
            max_new_tokens=512,   # max_length is redundant when max_new_tokens is set
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k_tok
        )
full = tokenizer.decode(output[0], skip_special_tokens=False)
start = "<|start_header_id|>assistant<|end_header_id|>\n\n"
startwithoutend = "<|start_header_id|>assistant"
end = "<|eot_id|>"
if start in full:
reply = full.split(start)[-1].split(end)[0].strip()
elif startwithoutend in full:
reply = full.split(startwithoutend)[-1].split(end)[0].strip()
else:
reply = full
return reply
    except Exception:
        tb = traceback.format_exc()
        return f"Error in app.py:\n{tb}"
finally:
torch.cuda.empty_cache()
def clear_kb(user_id="demo"):
if user_id in kb:
kb[user_id]["texts"].clear()
kb[user_id]["vecs"] = None
return f"Cleared KB for user '{user_id}'."
else:
return f"User ID '{user_id}' not found."
# ---- UI layout (feel free to tweak cosmetics) -----------------------------
with gr.Blocks() as demo:
gr.Markdown("### Tiny-RAG playground …")
# ---- passage ingestion ----
with gr.Row():
passage_box = gr.Textbox(lines=6, label="Reference passage")
user_id_box = gr.Textbox(value="demo", label="User ID")
chunk_box = gr.Slider(128, 2048, value=DEFAULT_CHUNK_SIZE,
step=64, label="Chunk size (chars)")
overlap_box = gr.Slider(0, 1024, value=DEFAULT_CHUNK_OVERLAP,
step=32, label="Chunk overlap")
store_btn = gr.Button("Store passage")
clear_btn = gr.Button("Clear KB")
status_box = gr.Markdown() # declare *before* wiring handlers
# ---- wire handlers (each button exactly once) ----
store_btn.click(
fn=store_doc,
inputs=[passage_box, user_id_box, chunk_box, overlap_box],
outputs=status_box
)
clear_btn.click(
fn=clear_kb,
inputs=user_id_box,
outputs=status_box
)
# ---------- Q & A ----------
question_box = gr.Textbox(lines=2, label="Ask a question")
    history_cb = gr.Radio(choices=["None", "Some", "All"], value="None",
                          label="Use stored passages")
system_box = gr.Textbox(lines=2, label="System prompt")
context_box = gr.Textbox(lines=6, label="Context passages")
    # Sampling sliders
temp_box = gr.Slider(0.0, 1.5, value=DEFAULT_TEMP,
step=0.05, label="Temperature")
topp_box = gr.Slider(0.0, 1.0, value=DEFAULT_TOP_P,
step=0.01, label="Top-p")
topk_box = gr.Slider(1, 100, value=DEFAULT_TOP_K_TOK,
step=1, label="Top-k (tokens)")
answer_btn = gr.Button("Answer")
answer_box = gr.Textbox(lines=6, label="Assistant reply")
answer_btn.click(
fn=answer,
inputs=[system_box, context_box, question_box,
user_id_box, history_cb,
temp_box, topp_box, topk_box],
outputs=answer_box
)
# ---------- 3. FastAPI layer --------------------------------------------------
class IngestReq(BaseModel):
    user_id: str
    docs: list[str]

class QueryReq(BaseModel):
    user_id: str
    question: str
api = FastAPI()
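# The module docstring promises /ingest and /query REST endpoints, but only the
# request models above exist so far. A minimal sketch under that assumption:
# the system prompt and the "Some" retrieval mode passed to answer() are
# placeholder choices, and the routes are registered before Gradio is mounted
# at "/" so the mount does not shadow them.
@api.post("/ingest")
def ingest(req: IngestReq):
    try:
        n = add_docs(req.user_id, req.docs)
        return {"stored_chunks": n,
                "total_passages": len(kb[req.user_id]["texts"])}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

@api.post("/query")
def query(req: QueryReq):
    try:
        reply = answer(system="You are a helpful assistant.",  # placeholder system prompt
                       context="", question=req.question,
                       user_id=req.user_id, history="Some")
        return {"answer": reply}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

# Example call once the server is running on port 8000:
#   curl -X POST localhost:8000/ingest -H "Content-Type: application/json" \
#        -d '{"user_id": "demo", "docs": ["some reference text"]}'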
api = gr.mount_gradio_app(api, demo, path="/")
# ---------- 5. run both (FastAPI + Gradio) -----------------------------------
if __name__ == "__main__":
# launch Gradio on a background thread
demo.queue().launch(share=False, prevent_thread_lock=True)
# then start FastAPI (uvicorn blocks main thread)
uvicorn.run(api, host="0.0.0.0", port=8000)