# smartedu / app.py
# Author: Bishal Sharma
# Last change: "Update app.py" (commit a069a62, verified)
# app.py
import os
import json
from pathlib import Path
from typing import List
import gradio as gr
import numpy as np
import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
from huggingface_hub import login
# (duplicate `import os` removed — `os` is already imported at the top of the file)
# Log in only when a token is actually configured: calling `login(None)` makes
# huggingface_hub try an interactive prompt, which fails inside a headless Space.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(_hf_token)
else:
    print("HF_TOKEN not set; skipping Hugging Face Hub login.")
# ---------- CONFIG ----------
MODEL_ID = "microsoft/DialoGPT-small"  # Much smaller model (~500MB instead of 14GB)
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # sentence embedder used for retrieval
DATA_DIR = Path("data")
FAISS_INDEX_PATH = DATA_DIR / "vector_store.index"  # serialized FAISS index built offline
META_PATH = DATA_DIR / "metadata.json"  # per-chunk metadata; row i describes index vector i
# RAG settings
TOP_K = 4  # number of chunks to retrieve
MAX_CONTEXT_TOKENS = 800  # Reduced for smaller model (tokenizer truncation limit for the prompt)
# Generation settings
GEN_MAX_NEW_TOKENS = 128  # Reduced for smaller model
TEMPERATURE = 0.7  # sampling temperature used by generate()
TOP_P = 0.9  # nucleus-sampling cutoff used by generate()
# ----------------------------
# ---------- Helpers ----------
def load_faiss_index(index_path: Path):
    """Read a serialized FAISS index from disk.

    Raises:
        FileNotFoundError: if `index_path` does not exist.
    """
    if index_path.exists():
        return faiss.read_index(str(index_path))
    raise FileNotFoundError(f"FAISS index not found at {index_path}")
def load_metadata(meta_path: Path):
    """Load the chunk-metadata list whose rows parallel the FAISS index.

    Raises:
        FileNotFoundError: if `meta_path` does not exist.
    """
    if not meta_path.exists():
        raise FileNotFoundError(f"metadata.json not found at {meta_path}")
    return json.loads(meta_path.read_text(encoding="utf-8"))
def embed_texts(model, texts: List[str]):
    """Encode `texts` with the given SentenceTransformer and L2-normalize
    the vectors (in place) so inner-product search acts like cosine similarity."""
    vectors = model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
    faiss.normalize_L2(vectors)
    return vectors
def retrieve_top_k(query: str, embed_model, faiss_index, metadata, top_k: int = TOP_K):
    """Return the `top_k` most similar metadata chunks for `query`.

    Each hit is a dict with the similarity score, the chunk text (preferring
    the `text_full` field, falling back to `text`), and the raw metadata entry.
    """
    query_vec = embed_model.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(query_vec)  # normalized query -> inner product == cosine
    scores, indices = faiss_index.search(query_vec.astype('float32'), top_k)
    hits = []
    for score, idx in zip(scores[0], indices[0]):
        entry = metadata[idx]
        hits.append(
            {
                "score": float(score),
                "text": entry.get("text_full") or entry.get("text"),
                "meta": entry,
            }
        )
    return hits
def build_prompt_from_chunks(question: str, chunks: List[dict]):
    """Assemble a compact RAG prompt for a small model.

    Each chunk contributes at most 400 characters of text, prefixed with its
    source file. Chunks stop being added once the *untruncated* running total
    would exceed 2000 characters — a deliberately tight budget for small LMs.
    """
    pieces = []
    budget_used = 0
    for chunk in chunks:
        body = chunk["text"]
        # Budget check uses the full chunk length, not the truncated part.
        if budget_used + len(body) > 2000:
            break
        source = chunk['meta'].get('source_file', 'unknown')
        pieces.append(f"Source: {source} - {body[:400]}")
        budget_used += len(body)
    context = "\n\n".join(pieces).strip()
    # Keep the prompt shape minimal: context, question, then an "Answer:" cue.
    prompt = f"""Context: {context}
Question: {question}
Answer:"""
    return prompt
# ---------- Model Loading ----------
@torch.no_grad()  # weight loading never needs autograd; decorator is harmless here
def load_generation_model(model_id: str):
    """
    Load tokenizer + causal LM for `model_id`.

    Strategy:
      * GPU available -> try 4-bit (bitsandbytes) quantization, fall back to fp16.
      * CPU only      -> plain fp32 load (quantization skipped on CPU), with a
                         last-resort `device_map="auto"` attempt.

    Returns:
        (tokenizer, model) tuple.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model {model_id} on device={device}")
    # NOTE(review): trust_remote_code=True executes code shipped in the model
    # repo — assumes MODEL_ID is trusted; confirm before switching models.
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)
    # Add padding token if missing (GPT-2-family models ship without one)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Choose loading kwargs depending on device
    if device == "cuda":
        try:
            # 4-bit quantization via bitsandbytes to minimize GPU memory
            bnb_config = BitsAndBytesConfig(load_in_4bit=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            )
            print("Model loaded with 4-bit quantization on GPU")
        except Exception as e:
            print("4-bit load failed:", e)
            # fallback to fp16 if 4-bit fails (e.g. bitsandbytes missing/incompatible)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True,
            )
            print("Model loaded with fp16 on GPU (no quantization)")
    else:
        # CPU fallback: avoid quantization on CPU as it causes issues
        try:
            print("Loading on CPU (quantization disabled)")
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            print("Model loaded on CPU")
        except Exception as e:
            print("CPU load failed, attempting with device_map:", e)
            # Last resort: let accelerate decide weight placement.
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                trust_remote_code=True
            )
    return tokenizer, model
# ---------- Initialization ----------
print("Initializing resources...")
# Module-level singletons shared by answer_question() and the Gradio handlers.
embed_model = SentenceTransformer(EMBED_MODEL)
faiss_index = load_faiss_index(FAISS_INDEX_PATH)
metadata = load_metadata(META_PATH)
# Try to load tokenizer + model (this can take a while)
try:
    tokenizer, gen_model = load_generation_model(MODEL_ID)
    # NOTE(review): generation_config is built here but answer_question()
    # passes the same settings to generate() as kwargs — this object appears
    # unused; confirm before removing.
    generation_config = GenerationConfig(
        max_new_tokens=GEN_MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    print("Model loaded.")
except Exception as e:
    print("Model load error:", e)
    # Degrade gracefully: answer_question() falls back to showing raw contexts.
    tokenizer, gen_model = None, None
# ---------- RAG + Generation ----------
def answer_question(user_question: str):
    """Answer `user_question` with retrieval-augmented generation.

    Retrieves TOP_K chunks from the FAISS store, builds a prompt, and has the
    causal LM complete it. If the generation model failed to load, returns a
    preview of the retrieved contexts instead.
    """
    if not user_question or user_question.strip() == "":
        return "Please enter a question."
    # Retrieve chunks
    retrieved = retrieve_top_k(user_question, embed_model, faiss_index, metadata, top_k=TOP_K)
    if not retrieved:
        return "No relevant content found in the vector store."
    # Build prompt
    prompt = build_prompt_from_chunks(user_question, retrieved)
    # If model not loaded, return retrieved chunks as fallback
    if gen_model is None or tokenizer is None:
        preview = "Model not loaded. Here are the retrieved contexts:\n\n"
        for i, r in enumerate(retrieved, 1):
            preview += f"\n--- Result {i} (score={r['score']:.3f}) ---\nSource: {r['meta'].get('source_file')} page {r['meta'].get('page')}\n{r['text'][:1000]}\n"
        return preview
    # Tokenize (possibly truncating the prompt) and move to the model's device
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_CONTEXT_TOKENS)
    if hasattr(gen_model, 'device'):
        inputs = inputs.to(gen_model.device)
    # generate
    with torch.no_grad():
        gen_ids = gen_model.generate(
            **inputs,
            max_new_tokens=GEN_MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
        )
    # BUGFIX: strip the prompt by token count instead of a string prefix.
    # When the prompt is truncated above, the decoded output no longer starts
    # with the original `prompt` string, so the old `out.startswith(prompt)`
    # check silently leaked the whole (truncated) prompt into the answer.
    prompt_token_count = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(gen_ids[0][prompt_token_count:], skip_special_tokens=True).strip()
    return answer if answer else "I couldn't generate a proper response."
# ---------- Gradio UI ----------
# CONSISTENCY FIX: the title hard-coded "Mistral-7B" while MODEL_ID now points
# at a much smaller model — derive the label from the config so it can't go stale.
with gr.Blocks(title=f"Virtual Teacher — RAG ({MODEL_ID})") as demo:
    gr.Markdown("# Virtual Teacher — RAG powered")
    with gr.Row():
        with gr.Column(scale=3):
            question = gr.Textbox(lines=3, label="Ask a question (about uploaded PDFs)")
            ask_btn = gr.Button("Ask")
            output = gr.Textbox(lines=18, label="Answer")
        with gr.Column(scale=1):
            gr.Markdown("### Retrieved Contexts (preview)")
            contexts = gr.Markdown("")
            info = gr.Markdown("Model: {}\nEmbed model: {}\nTop-K: {}".format(MODEL_ID, EMBED_MODEL, TOP_K))

    def on_ask(q):
        """Click handler: returns (generated answer, markdown context preview)."""
        # Run retrieval first so the context preview is available even when
        # generation is slow or the model failed to load.
        retrieved = retrieve_top_k(q, embed_model, faiss_index, metadata, top_k=TOP_K)
        ctx_preview = ""
        for i, r in enumerate(retrieved, 1):
            ctx_preview += f"**{i}. Source:** {r['meta'].get('source_file')} (page {r['meta'].get('page')}) \nScore: {r['score']:.3f}\n\n"
            txt = r['text']
            ctx_preview += txt[:1000] + ("..." if len(txt) > 1000 else "") + "\n\n"
        # generate answer (this is the slow step)
        answer = answer_question(q)
        return answer, ctx_preview

    ask_btn.click(on_ask, inputs=[question], outputs=[output, contexts])
    gr.Markdown("**Notes:** If the model doesn't load in this Space (OOM), try enabling Community GPU or use a smaller model.")
if __name__ == "__main__":
    # Bind to all interfaces; honor the platform-provided PORT when set.
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=port)