"""
app.py – Tiny-RAG (Gradio playground) + REST API (/ingest, /query)
"""
# ---------- 1. imports & global helpers -------------
import os

# Must be set BEFORE torch is imported for the allocator flag to take effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import math
import torch
import torch.nn.functional as F
import uvicorn
import gradio as gr
from collections import defaultdict
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
    AutoTokenizer,          # was imported twice in the original tuple
    AutoModelForCausalLM,
    AutoModel,
    AutoConfig,
)

# NOTE(review): env var is spelled "HF_token" (lowercase "token") — confirm it
# matches the Space secret name; the conventional spelling is "HF_TOKEN".
HF_TOKEN = os.getenv("HF_token")
CHAT_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
EMB_MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1"
MAX_PROMPT_TOKENS = 8192

# ---------- sampling & chunking defaults ------------------
DEFAULT_TEMP = 0.7
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K_TOK = 40       # token-level sampling (top-k)
DEFAULT_CHUNK_SIZE = 512     # characters per chunk
DEFAULT_CHUNK_OVERLAP = 128  # characters shared between consecutive chunks
def chunk_text(text: str, size: int, overlap: int):
    """Yield sliding-window chunks of *text* with character overlap.

    Args:
        text: the string to split.
        size: chunk length in characters (must be > 0).
        overlap: number of characters shared between consecutive chunks.

    The window advances by ``size - overlap``. If ``overlap >= size`` that
    step would be 0 (``range`` raises ValueError) or negative (``range`` is
    empty and the whole text is silently dropped) — in that case we fall
    back to non-overlapping windows instead.
    """
    if size <= 0:
        raise ValueError("size must be positive")
    step = size - overlap
    if step <= 0:  # overlap >= size: degrade gracefully to no overlap
        step = size
    for start in range(0, len(text), step):
        yield text[start : start + size]
# --- lazy loaders -------------------------------------------------------------
# Module-level handles for the chat LLM and the embedding model. They start
# as None and are populated on first use by load_chat() / load_embedder()
# below, so the Space can boot without downloading any weights.
tokenizer, chat_model = None, None
emb_tokenizer, emb_model = None, None
def load_chat():
    """Lazily initialise the chat tokenizer + model on first call (idempotent)."""
    global tokenizer, chat_model
    if tokenizer is not None:
        return  # already loaded
    tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_ID, token=HF_TOKEN)
    chat_model = AutoModelForCausalLM.from_pretrained(
        CHAT_MODEL_ID,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN,
    )
def load_embedder():
    """Lazily initialise the embedding tokenizer + model on first call (idempotent)."""
    global emb_tokenizer, emb_model
    if emb_tokenizer is not None:
        return  # already loaded
    emb_tokenizer = AutoTokenizer.from_pretrained(EMB_MODEL_ID, token=HF_TOKEN)
    cfg = AutoConfig.from_pretrained(EMB_MODEL_ID, token=HF_TOKEN)
    emb_model = AutoModel.from_pretrained(
        EMB_MODEL_ID,
        device_map="auto",
        torch_dtype=torch.float16,
        config=cfg,
        token=HF_TOKEN,
    )
    emb_model.eval()  # inference only — disables dropout etc.
@torch.no_grad()
def embed(text: str) -> torch.Tensor:
    """Embed *text* and return an L2-normalised ``[1, d]`` tensor on CPU.

    Pools the first (CLS) token of the encoder's last hidden state as the
    passage embedding.
    """
    load_embedder()
    # The original nested an explicit `with torch.no_grad():` here — redundant,
    # since the decorator already disables gradients for the whole call.
    inputs = emb_tokenizer(text, return_tensors="pt", truncation=True).to(emb_model.device)
    vec = emb_model(**inputs).last_hidden_state[:, 0]  # CLS pooling -> [1, d]
    return F.normalize(vec, dim=-1).cpu()
# ---------- 2. Tiny in-memory knowledge-base (shared by Gradio & API) --------
# One dict entry per user_id. Each entry holds:
#   • "texts": list[str] – the raw passages we ingested
#   • "vecs" : Tensor    – their embeddings, stacked row-wise
#     (NOTE(review): add_docs() torch.stack's [1, d] embeds, so the stored
#      tensor is actually [N, 1, d]; answer() squeezes the extra dim.)
# Unknown user_ids auto-create an empty entry on first access.
# --------------------------------------------------------------------------
kb = defaultdict(lambda: {"texts": [], "vecs": None})
def add_docs(user_id: str,
             docs: list[str],
             chunk_size: int = DEFAULT_CHUNK_SIZE,
             chunk_overlap: int = DEFAULT_CHUNK_OVERLAP) -> int:
    """Chunk *docs*, embed every non-blank chunk, append to *user_id*'s KB.

    Returns the number of chunks actually stored. Returns 0 for empty /
    whitespace-only input — the original fell through to ``torch.stack([])``
    in that case, which raises RuntimeError (store_doc() already expects a
    0 return here).
    """
    chunks: list[str] = []
    for d in docs:
        chunks.extend(chunk_text(d, chunk_size, chunk_overlap))
    docs = [c for c in chunks if c.strip()]
    if not docs:              # nothing embeddable — avoid torch.stack([])
        return 0
    load_embedder()           # lazy-load once
    new_vecs = torch.stack([embed(t) for t in docs]).cpu()
    store = kb[user_id]       # auto-creates via defaultdict
    store["texts"].extend(docs)
    store["vecs"] = (
        new_vecs if store["vecs"] is None
        else torch.cat([store["vecs"], new_vecs])
    )
    return len(docs)
# ----- chat prompt helper ---------------------------------------------------
def build_llm_prompt(system: str, context: list[str], user_question: str) -> str:
    """Render a chat-template prompt for LLaMA/Qwen-style chat models.

    Every non-empty context passage becomes its own user turn, followed by
    the actual question. The system prompt is hardened so the model outputs
    only the final reply (no explanations or step-by-step analysis).
    """
    load_chat()  # ensure the tokenizer is available

    hardened_system = (
        f"{system.strip()}\n"
        "Do not include any explanations, steps, or analysis. "
        "Only output the final reply content."
    )
    conversation = [{"role": "system", "content": hardened_system}]
    conversation += [
        {"role": "user", "content": passage.strip()}
        for passage in context
        if passage.strip()  # skip blank passages
    ]
    conversation.append({"role": "user", "content": user_question.strip()})

    # NOTE(review): add_generation_prompt=False means the assistant header is
    # NOT appended; answer() compensates when parsing the model output —
    # confirm this is intentional rather than an oversight.
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False
    )
# ---------- 4. Gradio playground --------------------------------------------
def store_doc(doc_text: str, user_id="demo", chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP):
    """UI callback: chunk + embed one passage into the user's KB.

    Returns a human-readable status string (also used for error reporting).
    """
    try:
        stored = add_docs(user_id, [doc_text], chunk_size, chunk_overlap)
        if stored == 0:
            return "Nothing stored (empty input)."
        return f"Stored — KB now has {len(kb[user_id]['texts'])} passage(s)."
    except Exception as e:
        return f"Error during storing: {e}"
import traceback
def answer(system: str, context: str, question: str,
           user_id="demo", history="None",
           temperature=DEFAULT_TEMP,
           top_p=DEFAULT_TOP_P,
           top_k_tok=DEFAULT_TOP_K_TOK):
    """UI callback: retrieve context, build a chat prompt, generate a reply.

    history selects retrieval: "None" (no KB lookup), "Some" (top-4 most
    similar chunks) or "All" (every stored passage).
    Returns the assistant reply string, or an error message on failure.
    """
    try:
        if not question.strip():
            return "Please ask a question."
        if history != "None" and not kb[user_id]["texts"]:
            return "No reference passage yet. Add one first."
        context_list = [context]
        # 1. Retrieve passages according to the history mode.
        if history == "Some":
            q_vec = embed(question).view(-1).cpu()
            store = kb[user_id]
            vecs = store["vecs"]
            if vecs is None or vecs.size(0) == 0:
                return "Knowledge base is empty or corrupted."
            sims = torch.matmul(vecs, q_vec)  # cosine sims (vecs normalised)
            if sims.dim() > 1:                # stored vecs are [N,1,d] -> [N,1]
                sims = sims.squeeze(1)
            k = min(4, sims.size(0))
            idxs = torch.topk(sims, k=k, dim=0).indices.tolist()
            context_list += [store["texts"][i] for i in idxs]
        elif history == "All":
            context_list += kb[user_id]["texts"]
        # 2. Build the chat prompt (helper defined earlier).
        prompt = build_llm_prompt(system, context_list, question)
        # 3. Tokenise & cap prompt length, keeping the most recent tokens.
        load_chat()
        tokens = tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,  # the chat template already added them
        )
        if tokens["input_ids"].size(1) > MAX_PROMPT_TOKENS:
            tokens = {name: v[:, -MAX_PROMPT_TOKENS:] for name, v in tokens.items()}
        tokens = {name: v.to(chat_model.device) for name, v in tokens.items()}
        # 4. Generate. The original also passed max_length alongside
        # max_new_tokens — contradictory; transformers warns and ignores it.
        output = chat_model.generate(
            **tokens,
            max_new_tokens=512,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k_tok,
            pad_token_id=tokenizer.eos_token_id,  # Llama defines no pad token
        )
        full = tokenizer.decode(output[0], skip_special_tokens=False)
        # 5. Extract the assistant turn from the Llama-3 template markers.
        start = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        startwithoutend = "<|start_header_id|>assistant"
        end = "<|eot_id|>"
        if start in full:
            reply = full.split(start)[-1].split(end)[0].strip()
        elif startwithoutend in full:
            reply = full.split(startwithoutend)[-1].split(end)[0].strip()
        else:
            reply = full
        return reply
    except Exception:
        # The original handler's f-string referenced `k` / `sims`, which are
        # unbound unless history == "Some" ran — it raised NameError and
        # masked the real traceback. Report only the traceback.
        return f"Error in app.py: {traceback.format_exc()}"
    finally:
        torch.cuda.empty_cache()
def clear_kb(user_id="demo"):
    """UI callback: drop all stored passages and embeddings for *user_id*."""
    if user_id not in kb:
        return f"User ID '{user_id}' not found."
    entry = kb[user_id]
    entry["texts"].clear()
    entry["vecs"] = None
    return f"Cleared KB for user '{user_id}'."
# ---- UI layout (feel free to tweak cosmetics) -----------------------------
# Component creation order defines the on-screen layout; keep wiring after
# all referenced components exist.
with gr.Blocks() as demo:
    gr.Markdown("### Tiny-RAG playground …")
    # ---- passage ingestion ----
    with gr.Row():
        passage_box = gr.Textbox(lines=6, label="Reference passage")
        user_id_box = gr.Textbox(value="demo", label="User ID")
        chunk_box = gr.Slider(128, 2048, value=DEFAULT_CHUNK_SIZE,
                              step=64, label="Chunk size (chars)")
        overlap_box = gr.Slider(0, 1024, value=DEFAULT_CHUNK_OVERLAP,
                                step=32, label="Chunk overlap")
    store_btn = gr.Button("Store passage")
    clear_btn = gr.Button("Clear KB")
    status_box = gr.Markdown()  # declare *before* wiring handlers
    # ---- wire handlers (each button exactly once) ----
    store_btn.click(
        fn=store_doc,
        inputs=[passage_box, user_id_box, chunk_box, overlap_box],
        outputs=status_box
    )
    clear_btn.click(
        fn=clear_kb,
        inputs=user_id_box,
        outputs=status_box
    )
    # ---------- Q & A ----------
    question_box = gr.Textbox(lines=2, label="Ask a question")
    # NOTE(review): free-text mode switch; answer() only recognises the exact
    # strings "None" / "Some" / "All" — a Radio/Dropdown would be safer.
    history_cb = gr.Textbox(value="None", label="Use chat history")
    system_box = gr.Textbox(lines=2, label="System prompt")
    context_box = gr.Textbox(lines=6, label="Context passages")
    # sampling sliders feed straight into chat_model.generate()
    temp_box = gr.Slider(0.0, 1.5, value=DEFAULT_TEMP,
                         step=0.05, label="Temperature")
    topp_box = gr.Slider(0.0, 1.0, value=DEFAULT_TOP_P,
                         step=0.01, label="Top-p")
    topk_box = gr.Slider(1, 100, value=DEFAULT_TOP_K_TOK,
                         step=1, label="Top-k (tokens)")
    answer_btn = gr.Button("Answer")
    answer_box = gr.Textbox(lines=6, label="Assistant reply")
    answer_btn.click(
        fn=answer,
        inputs=[system_box, context_box, question_box,
                user_id_box, history_cb,
                temp_box, topp_box, topk_box],
        outputs=answer_box
    )
# ---------- 3. FastAPI layer --------------------------------------------------
class IngestReq(BaseModel):
    """Request body for the REST ingestion endpoint named in the module docstring (/ingest)."""
    user_id: str      # KB bucket to ingest into
    docs: list[str]   # raw documents; chunking happens server-side
class QueryReq(BaseModel):
    """Request body for the REST query endpoint named in the module docstring (/query)."""
    user_id: str    # KB bucket to retrieve from
    question: str   # natural-language question
api = FastAPI()

# The module docstring promises /ingest and /query, and HTTPException was
# imported but never used — the REST endpoints were missing. Add them here,
# reusing IngestReq / QueryReq and the existing helpers.

@api.post("/ingest")
def ingest(req: IngestReq):
    """REST: chunk + embed req.docs into req.user_id's knowledge base."""
    try:
        stored = add_docs(req.user_id, req.docs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"stored_chunks": stored}

@api.post("/query")
def query(req: QueryReq):
    """REST: answer req.question using retrieval over req.user_id's KB."""
    if not kb[req.user_id]["texts"]:
        raise HTTPException(status_code=404,
                            detail=f"No documents stored for user '{req.user_id}'")
    reply = answer("You are a helpful assistant.", "", req.question,
                   user_id=req.user_id, history="Some")
    return {"answer": reply}

# Mount the Gradio UI last so the API routes above take precedence at their paths.
api = gr.mount_gradio_app(api, demo, path="/")
# ---------- 5. run both (FastAPI + Gradio) -----------------------------------
if __name__ == "__main__":
    # launch Gradio on a background thread (prevent_thread_lock returns immediately)
    # NOTE(review): the Gradio app is *also* mounted on the FastAPI app at "/",
    # so this extra launch() serves a second copy on Gradio's own default port —
    # confirm the duplication is intentional.
    demo.queue().launch(share=False, prevent_thread_lock=True)
    # then start FastAPI (uvicorn blocks the main thread)
    uvicorn.run(api, host="0.0.0.0", port=8000)