# Sharif_RAG_Bot / app.py
# aminhalvaei — "Create app.py" (commit 7b05361, verified)
import os
import json
import pickle
import threading
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
TextIteratorStreamer,
)
import gradio as gr
# ----------------------------
# Config (match your notebook)
# ----------------------------
# Hugging Face model IDs, the same ones the notebook uses. The embedding model
# is presumably the one that produced vector_index.faiss — keep them in sync.
EMBED_MODEL_NAME = "intfloat/multilingual-e5-large"
LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
# Retrieval artifacts exported by the notebook and uploaded next to app.py.
CHUNKS_PATH = "sharif_rules_chunked.json"  # JSON list of chunks
FAISS_PATH = "vector_index.faiss"  # a *pickled* faiss index (read with pickle.load)
BM25_PATH = "bm25_index.pkl"  # a pickled BM25 model (read with pickle.load)
# UI defaults. The k slider goes up to 6, as in the notebook's UI.
DEFAULT_K = 3
DEFAULT_MAX_CTX_CHARS = 1200
# ----------------------------
# Load artifacts
# ----------------------------
def load_artifacts():
    """Read the chunk list, the FAISS index and the BM25 model from disk.

    Returns:
        tuple: ``(chunks, vector_index, bm25)`` where ``chunks`` is the parsed
        JSON from CHUNKS_PATH and the other two are unpickled from FAISS_PATH
        and BM25_PATH respectively.

    Raises:
        FileNotFoundError: if the chunk file, or either index file, is absent.
    """
    if not os.path.exists(CHUNKS_PATH):
        raise FileNotFoundError(
            f"Missing {CHUNKS_PATH}. Upload it to the Space repo (recommended), "
            "or add code to build it at startup."
        )
    both_indexes_present = os.path.exists(FAISS_PATH) and os.path.exists(BM25_PATH)
    if not both_indexes_present:
        raise FileNotFoundError(
            f"Missing {FAISS_PATH} and/or {BM25_PATH}. Upload them to the Space repo."
        )

    with open(CHUNKS_PATH, "r", encoding="utf-8") as handle:
        chunks = json.load(handle)

    # SECURITY: pickle.load can execute arbitrary code while deserializing.
    # Only ship index files you produced yourself; never load user uploads here.
    unpickled = []
    for artifact_path in (FAISS_PATH, BM25_PATH):
        with open(artifact_path, "rb") as handle:
            unpickled.append(pickle.load(handle))
    vector_index, bm25 = unpickled

    return chunks, vector_index, bm25
# --- Retrieval side: query encoder + prebuilt indexes -----------------------
print("Loading embedding model...")
embed_model = SentenceTransformer(EMBED_MODEL_NAME)
print("Loading retrieval artifacts...")
chunks, vector_index, bm25 = load_artifacts()

# --- Generation side: 4-bit quantized instruction LLM -----------------------
print("Loading LLM + tokenizer...")
# bitsandbytes stores the weights as 4-bit NF4 and computes in fp16
# (presumably so the 7B model fits in the Space's GPU memory).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained(
    LLM_MODEL_NAME,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME,
    device_map="auto",  # let accelerate place the weights on available devices
    quantization_config=bnb_config,
    trust_remote_code=True,
)
model.eval()  # inference only
print("All models loaded.")
# ----------------------------
# Retrieval (match notebook)
# ----------------------------
def reciprocal_rank_fusion(ranked_lists, k, rrf_k=60):
    """Fuse ranked lists of document ids with Reciprocal Rank Fusion.

    Every list adds ``1 / (rank + rrf_k)`` to each id it contains, where
    ``rank`` is the 0-based position in that list (the notebook's formula;
    textbook RRF uses 1-based ranks). Negative ids are skipped: FAISS pads its
    results with ``-1`` when fewer than ``k`` vectors exist, and such an id
    would otherwise index ``chunks[-1]`` and silently return the last chunk.

    Args:
        ranked_lists: iterables of integer ids (Python or numpy ints), best first.
        k: maximum number of fused ids to return.
        rrf_k: RRF smoothing constant; must be positive.

    Returns:
        list[int]: up to ``k`` ids ordered by fused score, highest first; ties
        keep the order in which ids were first seen.

    Raises:
        ValueError: if ``rrf_k`` is not positive.
    """
    if rrf_k <= 0:
        raise ValueError(f"rrf_k must be positive, got {rrf_k}")
    fusion_scores = {}
    for ranking in ranked_lists:
        for rank, idx in enumerate(ranking):
            idx = int(idx)
            if idx < 0:  # FAISS "no result" padding
                continue
            fusion_scores[idx] = fusion_scores.get(idx, 0) + 1 / (rank + rrf_k)
    # sorted() is stable even with reverse=True, so equal scores keep insertion order.
    return sorted(fusion_scores, key=fusion_scores.get, reverse=True)[:k]


def hybrid_search(query: str, k: int = 5, rrf_k: int = 60):
    """Hybrid search (dense FAISS + BM25) fused with Reciprocal Rank Fusion.

    Follows the notebook's logic: take the top ``k`` ids from each retriever,
    fuse them with RRF and return the best ``k`` chunks.

    Args:
        query: the user's question.
        k: candidates taken from each retriever, and max number of chunks returned.
        rrf_k: RRF constant (default 60, as in the notebook).

    Returns:
        list: entries of the module-level ``chunks`` list, best first. May be
        shorter than ``k`` when fewer distinct documents were retrieved.
    """
    # 1) Dense retrieval on the normalized query embedding.
    # NOTE(review): E5 models are normally queried with a "query: " prefix;
    # none is added here, presumably to match how the notebook built the
    # index — confirm before changing.
    query_embedding = embed_model.encode([query], normalize_embeddings=True)
    _, v_indices = vector_index.search(query_embedding, k)

    # 2) Sparse retrieval. Whitespace tokenization presumably matches how the
    # BM25 corpus was tokenized in the notebook — verify if results look off.
    bm25_scores = bm25.get_scores(query.split())
    bm25_indices = np.argsort(bm25_scores)[::-1][:k]

    # 3) RRF fusion (drops FAISS -1 padding).
    fused_ids = reciprocal_rank_fusion([v_indices[0], bm25_indices], k, rrf_k=rrf_k)
    return [chunks[i] for i in fused_ids]
# ----------------------------
# Prompt + generation
# ----------------------------
# System prompt for the LLM, in Persian. English translation for maintainers:
#   "You are an intelligent educational assistant for Sharif University of
#    Technology. Your task is to answer the student's questions accurately,
#    based on the "regulations text" below.
#    Important rules:
#    1. Use only the information in the [Context] section; do not use your
#       prior knowledge.
#    2. If the answer is not in the text, say exactly: "No information on this
#       was found in the available regulations."
#    3. The final answer must be entirely in Persian.
#    4. Mention the regulation's name and the article (ماده) or note (تبصره)
#       number in the answer."
# The refusal sentence in rule 2 is the same string rag_answer_ui_stream yields
# when retrieval returns nothing. Edit this text carefully: it contains
# zero-width non-joiner characters inside words.
SYSTEM_PROMPT_FA = """شما یک دستیار هوشمند آموزشی برای دانشگاه صنعتی شریف هستید.
وظیفه شما پاسخ‌دهی دقیق به سوالات دانشجو بر اساس "متن قوانین" زیر است.
قوانین مهم:
1. فقط و فقط از اطلاعات موجود در بخش [Context] استفاده کنید. از دانش قبلی خود استفاده نکنید.
2. اگر پاسخ سوال در متن موجود نیست، دقیقاً بگویید: "اطلاعاتی در این مورد در آیین‌نامه‌های موجود یافت نشد."
3. پاسخ نهایی باید کاملاً به زبان فارسی باشد.
4. نام آیین‌نامه و شماره ماده یا تبصره را در پاسخ ذکر کنید.
"""
def build_context_text(retrieved_chunks, max_ctx_chars: int):
    """Assemble the ``[Context]`` section of the prompt from retrieved chunks.

    Each chunk becomes a header line
    ``Document N (Source: <title>, Article: <article>):`` followed by its
    stripped text cut to ``max_ctx_chars`` characters and a blank line.
    Title and article come from ``chunk["metadata"]``; a missing or None
    metadata dict falls back to "Unknown" / "N/A".

    Returns:
        str: the concatenated sections ("" when there are no chunks).
    """
    sections = []
    for doc_number, chunk in enumerate(retrieved_chunks, start=1):
        meta = chunk.get("metadata", {}) or {}
        source_title = meta.get("title", "Unknown")
        article_no = meta.get("article", "N/A")
        body = (chunk.get("text", "") or "").strip()[: int(max_ctx_chars)]
        sections.append(
            f"Document {doc_number} (Source: {source_title}, Article: {article_no}):\n"
            f"{body}\n\n"
        )
    return "".join(sections)
def generate_answer_stream(query: str, retrieved_chunks, max_ctx_chars: int = 1200):
    """Stream the LLM's answer to ``query``, grounded on ``retrieved_chunks``.

    ``model.generate`` runs in a worker thread and pushes decoded text into a
    ``TextIteratorStreamer``. This generator drains the streamer and yields the
    growing answer after every piece it receives.

    Args:
        query: the user's question (inserted verbatim into the prompt).
        retrieved_chunks: chunks returned by ``hybrid_search``.
        max_ctx_chars: per-chunk character cap passed to ``build_context_text``.

    Yields:
        str: the answer accumulated so far (the prompt itself is excluded).

    Raises:
        Exception: any error raised by ``model.generate`` is re-raised here
            after the stream is closed, instead of leaving the caller blocked.
    """
    context_text = build_context_text(retrieved_chunks, max_ctx_chars=max_ctx_chars)
    # Prompt layout: question line, "[Context]:" marker, context block, answer cue.
    user_prompt = f"سوال: {query}\n[Context]:\n{context_text}\nپاسخ:"
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_FA},
        {"role": "user", "content": user_prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_special_tokens=True,
        # keep the prompt out of the stream (only the assistant answer is wanted)
        skip_prompt=True,
    )
    gen_kwargs = dict(
        **model_inputs,
        max_new_tokens=512,
        # NOTE(review): temperature/top_p only take effect when sampling is on.
        # do_sample is not passed here, so it comes from the model's generation
        # config — confirm that is the intended decoding mode.
        temperature=0.1,
        top_p=0.9,
        streamer=streamer,
    )

    errors = []

    def _run_generation():
        # Worker thread. If generate() raises, the streamer never receives its
        # end-of-stream signal and the consumer loop below would block forever.
        # Record the error and close the stream so it can be re-raised there.
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # re-raised in the consuming thread below
            errors.append(exc)
            streamer.end()

    thread = threading.Thread(target=_run_generation)
    thread.start()
    partial = ""
    for token_text in streamer:
        partial += token_text
        yield partial
    thread.join()
    if errors:
        raise errors[0]
# ----------------------------
# UI helpers (match your demo)
# ----------------------------
def format_sources(retrieved_docs, max_chars=300):
    """Render retrieved chunks as a numbered list for the debug sources box.

    Each entry has three lines: ``N. <title>``, the source and article (ماده)
    from the chunk's metadata, and a one-line snippet of the chunk text cut to
    ``max_chars`` characters ("…" appended when it was cut). Entries are
    separated by a blank line.
    """
    def _render(position, doc):
        meta = doc.get("metadata", {}) or {}
        flat_text = (doc.get("text", "") or "").strip().replace("\n", " ")
        clipped = flat_text[:max_chars]
        if len(flat_text) > max_chars:
            clipped += "…"
        return (
            f"{position}. {meta.get('title', '')}\n"
            f" source: {meta.get('source', '')} | ماده: {meta.get('article', '-')}\n"
            f" snippet: {clipped}"
        )

    return "\n\n".join(_render(pos, doc) for pos, doc in enumerate(retrieved_docs, start=1))
def rag_answer_ui_stream(question, k, max_ctx_chars):
    """Gradio callback: retrieve chunks for ``question`` and stream the answer.

    Yields ``(answer_text, sources_text)`` pairs for the answer and sources boxes.

    Args:
        question: text from the question box (may be empty or None).
        k: number of chunks to retrieve (slider value, converted with int()).
        max_ctx_chars: per-chunk character cap for the prompt (converted with int()).
    """
    if not question or not question.strip():
        yield "لطفاً سوال را وارد کنید.", ""
        return
    # 1) Retrieve
    retrieved = hybrid_search(question, k=int(k))
    if not retrieved:
        yield "اطلاعاتی در این مورد در آیین‌نامه‌های موجود یافت نشد.", ""
        return
    # 2) Sources are fixed for the whole stream.
    sources_text = format_sources(retrieved)
    # Show the sources and clear the previous answer before generation starts.
    # The first token only arrives after the prompt has been processed, and
    # until then the UI would otherwise keep showing the previous question's output.
    yield "", sources_text
    # 3) Stream answer
    for partial_answer in generate_answer_stream(
        question,
        retrieved,
        max_ctx_chars=int(max_ctx_chars),
    ):
        yield partial_answer, sources_text
# Gradio UI: question box, retrieval/context sliders, a run button, and two
# output boxes fed by the rag_answer_ui_stream generator. `demo` stays at
# module level so `gradio app.py` (reload mode) can also find it.
with gr.Blocks(title="Sharif RAG Demo (Streaming)") as demo:
    gr.Markdown(
        "## 🎓 Sharif Regulations RAG Bot (Streaming)\n"
        "سوال خود را وارد کنید. پاسخ فقط بر اساس متن‌های بازیابی‌شده تولید می‌شود."
    )
    with gr.Row():
        question = gr.Textbox(
            label="❓ Question (Persian)",
            placeholder="مثلاً: شرایط مهمانی در دوره روزانه؟",
            lines=2,
        )
    with gr.Row():
        k = gr.Slider(1, 6, value=DEFAULT_K, step=1, label="🔎 Number of retrieved chunks (k)")
        max_ctx_chars = gr.Slider(300, 2500, value=DEFAULT_MAX_CTX_CHARS, step=100, label="✂️ Max chars per chunk (for generation)")
    run_btn = gr.Button("Run RAG (stream)")
    answer_out = gr.Textbox(label="🤖 Answer (streaming)", lines=10)
    sources_out = gr.Textbox(label="📚 Retrieved sources (debug)", lines=12)
    run_btn.click(
        fn=rag_answer_ui_stream,
        inputs=[question, k, max_ctx_chars],
        outputs=[answer_out, sources_out],
    )

# Launch only when executed as a script (Spaces runs `python app.py`), so that
# importing this module does not start a server. queue() lets the generator
# callback push partial (streaming) updates. server_name="0.0.0.0" binds all
# interfaces, which makes the app reachable from outside a container.
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)