# Hugging Face Spaces app.py — Sharif University regulations RAG demo (streaming).
import json
import os
import pickle
import threading

import faiss
import gradio as gr
import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
# ----------------------------
# Config (match your notebook)
# ----------------------------
# Models must match the ones used to build the retrieval artifacts in the notebook.
EMBED_MODEL_NAME = "intfloat/multilingual-e5-large"
LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
# Artifact files produced by the indexing notebook (uploaded to the Space repo).
CHUNKS_PATH = "sharif_rules_chunked.json"
FAISS_PATH = "vector_index.faiss"  # pickle-dumped FAISS index
BM25_PATH = "bm25_index.pkl"  # pickle-dumped BM25Okapi index
# UI defaults; the slider allows k up to 6.
DEFAULT_K = 3
DEFAULT_MAX_CTX_CHARS = 1200
# ----------------------------
# Load artifacts
# ----------------------------
def load_artifacts():
    """Load the chunk corpus, FAISS vector index, and BM25 index from disk.

    Returns:
        tuple: ``(chunks, vector_index, bm25)`` — the list of chunk dicts
        from JSON, the unpickled FAISS index, and the unpickled BM25Okapi.

    Raises:
        FileNotFoundError: listing exactly which artifact files are missing.
    """
    # Check all three paths up front and report every missing file at once,
    # instead of the original vague "and/or" message.
    missing = [p for p in (CHUNKS_PATH, FAISS_PATH, BM25_PATH) if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(
            f"Missing artifact file(s): {', '.join(missing)}. "
            "Upload them to the Space repo (recommended), "
            "or add code to build them at startup."
        )
    with open(CHUNKS_PATH, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    # SECURITY: pickle.load executes arbitrary code from the file. These
    # artifacts must only ever come from the repo owner's own notebook;
    # never point these paths at untrusted uploads.
    with open(FAISS_PATH, "rb") as f:
        vector_index = pickle.load(f)
    with open(BM25_PATH, "rb") as f:
        bm25 = pickle.load(f)
    return chunks, vector_index, bm25
| print("Loading embedding model...") | |
| embed_model = SentenceTransformer(EMBED_MODEL_NAME) | |
| print("Loading retrieval artifacts...") | |
| chunks, vector_index, bm25 = load_artifacts() | |
| print("Loading LLM + tokenizer...") | |
| bnb_config = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_compute_dtype=torch.float16, | |
| bnb_4bit_quant_type="nf4", | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME, trust_remote_code=True) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| LLM_MODEL_NAME, | |
| quantization_config=bnb_config, | |
| device_map="auto", | |
| trust_remote_code=True, | |
| ) | |
| model.eval() | |
| print("All models loaded.") | |
# ----------------------------
# Retrieval (match notebook)
# ----------------------------
def hybrid_search(query: str, k: int = 5):
    """Hybrid retrieval: dense (FAISS) + lexical (BM25), fused with Reciprocal Rank Fusion.

    Args:
        query: User question; embedded as-is for the dense search and
            whitespace-tokenized for BM25.
        k: Number of candidates taken from each retriever and returned after fusion.

    Returns:
        list: Up to k chunk dicts from the module-level ``chunks``, best fused rank first.
    """
    # 1) Dense retrieval. Embeddings are normalized, so the index's inner
    # product behaves as cosine similarity.
    query_embedding = embed_model.encode([query], normalize_embeddings=True)
    v_scores, v_indices = vector_index.search(query_embedding, k)
    # 2) Lexical retrieval: BM25 over simple whitespace tokens.
    tokenized_query = query.split()
    bm25_scores = bm25.get_scores(tokenized_query)
    bm25_indices = np.argsort(bm25_scores)[::-1][:k]
    # 3) Reciprocal Rank Fusion (standard smoothing constant 60).
    fusion_scores = {}
    for rank, idx in enumerate(v_indices[0]):
        # FIX: FAISS pads results with -1 when fewer than k neighbors exist;
        # fusing -1 would later return chunks[-1] (the last chunk) spuriously.
        if idx < 0:
            continue
        key = int(idx)  # normalize numpy int -> plain int dict keys
        fusion_scores[key] = fusion_scores.get(key, 0.0) + 1.0 / (rank + 60)
    for rank, idx in enumerate(bm25_indices):
        key = int(idx)
        fusion_scores[key] = fusion_scores.get(key, 0.0) + 1.0 / (rank + 60)
    sorted_indices = sorted(fusion_scores, key=fusion_scores.get, reverse=True)[:k]
    return [chunks[i] for i in sorted_indices]
# ----------------------------
# Prompt + generation
# ----------------------------
# Persian system prompt: answer strictly from the [Context] section, reply with a
# fixed refusal sentence when the answer is absent, respond only in Persian, and
# cite the regulation name plus article/note number. (Runtime string — do not edit
# casually; the refusal sentence is also emitted verbatim by the UI callback.)
SYSTEM_PROMPT_FA = """شما یک دستیار هوشمند آموزشی برای دانشگاه صنعتی شریف هستید.
وظیفه شما پاسخدهی دقیق به سوالات دانشجو بر اساس "متن قوانین" زیر است.
قوانین مهم:
1. فقط و فقط از اطلاعات موجود در بخش [Context] استفاده کنید. از دانش قبلی خود استفاده نکنید.
2. اگر پاسخ سوال در متن موجود نیست، دقیقاً بگویید: "اطلاعاتی در این مورد در آییننامههای موجود یافت نشد."
3. پاسخ نهایی باید کاملاً به زبان فارسی باشد.
4. نام آییننامه و شماره ماده یا تبصره را در پاسخ ذکر کنید.
"""
def build_context_text(retrieved_chunks, max_ctx_chars: int):
    """Assemble the numbered "[Context]" section fed to the LLM prompt.

    Each chunk becomes one "Document N (Source: ..., Article: ...)" section,
    with its text stripped and clipped to max_ctx_chars characters.
    """
    limit = int(max_ctx_chars)
    sections = []
    for num, chunk in enumerate(retrieved_chunks, start=1):
        # Chunks carry title/article under chunk["metadata"]; tolerate None.
        meta = chunk.get("metadata", {}) or {}
        body = (chunk.get("text", "") or "").strip()[:limit]
        sections.append(
            f"Document {num} (Source: {meta.get('title', 'Unknown')}, "
            f"Article: {meta.get('article', 'N/A')}):\n{body}\n\n"
        )
    return "".join(sections)
def generate_answer_stream(query: str, retrieved_chunks, max_ctx_chars: int = 1200):
    """Stream the LLM's answer token-by-token via TextIteratorStreamer.

    Args:
        query: The user's question (Persian).
        retrieved_chunks: Chunk dicts returned by hybrid_search.
        max_ctx_chars: Per-chunk character cap applied when building the context.

    Yields:
        str: The growing partial answer (prompt excluded).
    """
    context_text = build_context_text(retrieved_chunks, max_ctx_chars=max_ctx_chars)
    user_prompt = f"""سوال: {query}
[Context]:
{context_text}
پاسخ:"""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_FA},
        {"role": "user", "content": user_prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_special_tokens=True,
        # keep prompt out of the stream (we only want the assistant answer)
        skip_prompt=True,
    )
    gen_kwargs = dict(
        **model_inputs,
        max_new_tokens=512,
        # FIX: temperature/top_p are ignored (with a warning) unless sampling
        # is enabled — generate() defaults to greedy decoding. do_sample=True
        # makes the low-temperature sampling actually take effect.
        do_sample=True,
        temperature=0.1,
        top_p=0.9,
        streamer=streamer,
    )
    # generate() blocks until done, so run it on a worker thread and consume
    # the streamer on this one.
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    partial = ""
    try:
        for token_text in streamer:
            partial += token_text
            yield partial
    finally:
        # FIX: join in finally so the worker is reaped even if the consumer
        # abandons this generator before the stream is exhausted.
        thread.join()
# ----------------------------
# UI helpers (match your demo)
# ----------------------------
def format_sources(retrieved_docs, max_chars=300):
    """Render retrieved chunks as a numbered, human-readable debug listing."""
    entries = []
    for num, doc in enumerate(retrieved_docs, 1):
        meta = doc.get("metadata", {}) or {}
        # Flatten the chunk text to one line, then clip with an ellipsis.
        body = (doc.get("text", "") or "").strip().replace("\n", " ")
        clipped = body[:max_chars]
        if len(body) > max_chars:
            clipped += "…"
        entries.append(
            f"{num}. {meta.get('title', '')}\n"
            f" source: {meta.get('source', '')} | ماده: {meta.get('article', '-')}\n"
            f" snippet: {clipped}"
        )
    return "\n\n".join(entries)
def rag_answer_ui_stream(question, k, max_ctx_chars):
    """Gradio callback: retrieve chunks for `question`, then stream the answer.

    Yields (partial_answer, sources_text) pairs so the answer textbox updates
    live while the retrieved-sources textbox stays fixed.
    """
    # Guard: blank or whitespace-only question.
    if not (question and question.strip()):
        yield "لطفاً سوال را وارد کنید.", ""
        return
    # Step 1: hybrid retrieval.
    retrieved = hybrid_search(question, k=int(k))
    # Guard: nothing retrieved at all.
    if not retrieved:
        yield "اطلاعاتی در این مورد در آییننامههای موجود یافت نشد.", ""
        return
    # Step 2: sources listing, computed once up front.
    sources_text = format_sources(retrieved)
    # Step 3: stream the answer, re-yielding the static sources each time.
    answer_stream = generate_answer_stream(
        question,
        retrieved,
        max_ctx_chars=int(max_ctx_chars),
    )
    for partial_answer in answer_stream:
        yield partial_answer, sources_text
# Gradio UI wiring: one question box, two sliders (k, per-chunk char cap),
# and two outputs (streamed answer + static retrieved-sources debug view).
with gr.Blocks(title="Sharif RAG Demo (Streaming)") as demo:
    gr.Markdown(
        "## 🎓 Sharif Regulations RAG Bot (Streaming)\n"
        "سوال خود را وارد کنید. پاسخ فقط بر اساس متنهای بازیابیشده تولید میشود."
    )
    with gr.Row():
        question = gr.Textbox(
            label="❓ Question (Persian)",
            placeholder="مثلاً: شرایط مهمانی در دوره روزانه؟",
            lines=2,
        )
    with gr.Row():
        k = gr.Slider(1, 6, value=DEFAULT_K, step=1, label="🔎 Number of retrieved chunks (k)")
        max_ctx_chars = gr.Slider(300, 2500, value=DEFAULT_MAX_CTX_CHARS, step=100, label="✂️ Max chars per chunk (for generation)")
    run_btn = gr.Button("Run RAG (stream)")
    answer_out = gr.Textbox(label="🤖 Answer (streaming)", lines=10)
    sources_out = gr.Textbox(label="📚 Retrieved sources (debug)", lines=12)
    # The callback is a generator; Gradio streams each yielded tuple into the outputs.
    run_btn.click(
        fn=rag_answer_ui_stream,
        inputs=[question, k, max_ctx_chars],
        outputs=[answer_out, sources_out],
    )
# Spaces will call app.py; server_name makes it work in containers too
demo.queue().launch(server_name="0.0.0.0", server_port=7860)