# app.py
"""RAG question-answering demo: FAISS retrieval + a small causal LM, served via Gradio.

Pipeline: embed the user question with a SentenceTransformer, retrieve the
top-k chunks from a prebuilt FAISS index, stuff them into a prompt, and
generate an answer with a small causal language model.
"""

import json
import os
from pathlib import Path
from typing import List

import faiss
import gradio as gr
import numpy as np
import torch
from huggingface_hub import login
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
)

# Only log in when a token is actually configured: login(None) raises,
# which previously crashed the app at import time on machines without HF_TOKEN.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(_hf_token)

# ---------- CONFIG ----------
MODEL_ID = "microsoft/DialoGPT-small"  # Much smaller model (~500MB instead of 14GB)
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
DATA_DIR = Path("data")
FAISS_INDEX_PATH = DATA_DIR / "vector_store.index"
META_PATH = DATA_DIR / "metadata.json"

# RAG settings
TOP_K = 4  # number of chunks to retrieve
MAX_CONTEXT_TOKENS = 800  # Reduced for smaller model

# Generation settings
GEN_MAX_NEW_TOKENS = 128  # Reduced for smaller model
TEMPERATURE = 0.7
TOP_P = 0.9
# ----------------------------


# ---------- Helpers ----------
def load_faiss_index(index_path: Path):
    """Read a serialized FAISS index from disk.

    Raises:
        FileNotFoundError: if *index_path* does not exist.
    """
    if not index_path.exists():
        raise FileNotFoundError(f"FAISS index not found at {index_path}")
    return faiss.read_index(str(index_path))


def load_metadata(meta_path: Path):
    """Load the chunk-metadata list whose rows parallel the FAISS index rows.

    Raises:
        FileNotFoundError: if *meta_path* does not exist.
    """
    if not meta_path.exists():
        raise FileNotFoundError(f"metadata.json not found at {meta_path}")
    with open(meta_path, "r", encoding="utf-8") as f:
        return json.load(f)


def embed_texts(model, texts: List[str]):
    """Embed *texts* and L2-normalize in place (so inner product == cosine)."""
    embs = model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
    # normalize for cosine
    faiss.normalize_L2(embs)
    return embs


def retrieve_top_k(query: str, embed_model, faiss_index, metadata, top_k: int = TOP_K):
    """Return up to *top_k* scored chunks for *query*.

    Each result is ``{"score": float, "text": str, "meta": dict}``.
    FAISS pads its result rows with index -1 when the store holds fewer than
    *top_k* vectors; those placeholders are skipped here (indexing metadata
    with -1 would silently return the *last* chunk).
    """
    q_emb = embed_model.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)
    D, I = faiss_index.search(q_emb.astype("float32"), top_k)
    results = []
    for score, idx in zip(D[0], I[0]):
        if idx < 0:  # FAISS "no result" sentinel
            continue
        meta = metadata[idx]
        results.append(
            {
                "score": float(score),
                "text": meta.get("text_full") or meta.get("text"),
                "meta": meta,
            }
        )
    return results


def build_prompt_from_chunks(question: str, chunks: List[dict]):
    """Create a simpler prompt for smaller models.

    Concatenates chunk texts (each capped at 400 chars, total capped at
    ~2000 chars) into a plain "Context / Question / Answer:" template.
    """
    context_parts = []
    total_chars = 0
    for ch in chunks:
        t = ch["text"]
        # Much smaller context for smaller models
        if total_chars + len(t) > 2000:  # Reduced from 15000
            break
        context_parts.append(
            f"Source: {ch['meta'].get('source_file','unknown')} - {t[:400]}"
        )  # Truncate chunks
        total_chars += len(t)

    context = "\n\n".join(context_parts).strip()

    # Much simpler prompt for smaller models
    prompt = f"""Context: {context}

Question: {question}

Answer:"""
    return prompt


# ---------- Model Loading ----------
# NOTE: the original decorated this with @torch.no_grad(); that is pointless on
# a loading function (no forward pass happens here) and has been removed.
def load_generation_model(model_id: str):
    """Load tokenizer + causal LM, preferring 4-bit on GPU with graceful fallbacks.

    GPU: try 4-bit (BitsAndBytesConfig), fall back to fp16.
    CPU: plain fp32 load (quantization on CPU causes issues), with a
    device_map="auto" last resort.

    Returns:
        (tokenizer, model)
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model {model_id} on device={device}")

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)

    # Add padding token if missing (DialoGPT-style models ship without one)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Choose loading kwargs depending on device
    if device == "cuda":
        try:
            bnb_config = BitsAndBytesConfig(load_in_4bit=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            )
            print("Model loaded with 4-bit quantization on GPU")
        except Exception as e:
            print("4-bit load failed:", e)
            # fallback to fp16 if 4-bit fails
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True,
            )
            print("Model loaded with fp16 on GPU (no quantization)")
    else:
        # CPU fallback: avoid quantization on CPU as it causes issues
        try:
            print("Loading on CPU (quantization disabled)")
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            print("Model loaded on CPU")
        except Exception as e:
            print("CPU load failed, attempting with device_map:", e)
            model = AutoModelForCausalLM.from_pretrained(
                model_id, device_map="auto", trust_remote_code=True
            )

    return tokenizer, model


# ---------- Initialization ----------
print("Initializing resources...")
embed_model = SentenceTransformer(EMBED_MODEL)
faiss_index = load_faiss_index(FAISS_INDEX_PATH)
metadata = load_metadata(META_PATH)

# Try to load tokenizer + model (this can take a while). On failure the app
# still serves retrieval-only previews (see answer_question fallback).
try:
    tokenizer, gen_model = load_generation_model(MODEL_ID)
    generation_config = GenerationConfig(
        max_new_tokens=GEN_MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    print("Model loaded.")
except Exception as e:
    print("Model load error:", e)
    tokenizer, gen_model = None, None


# ---------- RAG + Generation ----------
def answer_question(user_question: str):
    """Retrieve context for *user_question* and generate an answer.

    Falls back to showing the raw retrieved chunks when the generation
    model failed to load.
    """
    if not user_question or user_question.strip() == "":
        return "Please enter a question."

    # Retrieve chunks
    retrieved = retrieve_top_k(user_question, embed_model, faiss_index, metadata, top_k=TOP_K)
    if not retrieved:
        return "No relevant content found in the vector store."

    # Build prompt
    prompt = build_prompt_from_chunks(user_question, retrieved)

    # If model not loaded, return retrieved chunks as fallback
    if gen_model is None or tokenizer is None:
        preview = "Model not loaded. Here are the retrieved contexts:\n\n"
        for i, r in enumerate(retrieved, 1):
            preview += (
                f"\n--- Result {i} (score={r['score']:.3f}) ---\n"
                f"Source: {r['meta'].get('source_file')} page {r['meta'].get('page')}\n"
                f"{r['text'][:1000]}\n"
            )
        return preview

    # Tokenize and move to device
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_CONTEXT_TOKENS)

    # Move inputs to same device as model
    if hasattr(gen_model, "device"):
        inputs = inputs.to(gen_model.device)

    # generate
    with torch.no_grad():
        gen_ids = gen_model.generate(
            **inputs,
            max_new_tokens=GEN_MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
        )

    out = tokenizer.decode(gen_ids[0], skip_special_tokens=True)

    # Strip the prompt from output if tokenizer returns full text
    if out.startswith(prompt):
        answer = out[len(prompt):].strip()
    else:
        answer = out.strip()

    return answer if answer else "I couldn't generate a proper response."


# ---------- Gradio UI ----------
with gr.Blocks(title="Virtual Teacher — RAG + Mistral-7B") as demo:
    gr.Markdown("# Virtual Teacher — RAG powered")
    with gr.Row():
        with gr.Column(scale=3):
            question = gr.Textbox(lines=3, label="Ask a question (about uploaded PDFs)")
            ask_btn = gr.Button("Ask")
            output = gr.Textbox(lines=18, label="Answer")
        with gr.Column(scale=1):
            gr.Markdown("### Retrieved Contexts (preview)")
            contexts = gr.Markdown("")
            info = gr.Markdown(
                "Model: {}\nEmbed model: {}\nTop-K: {}".format(MODEL_ID, EMBED_MODEL, TOP_K)
            )

    def on_ask(q):
        """Button handler: show retrieved contexts, then the generated answer."""
        # Reject blank input before running retrieval (previously an empty
        # string was embedded and searched for no reason).
        if not q or not q.strip():
            return "Please enter a question.", ""

        # run retrieval first to show contexts
        retrieved = retrieve_top_k(q, embed_model, faiss_index, metadata, top_k=TOP_K)
        ctx_preview = ""
        for i, r in enumerate(retrieved, 1):
            ctx_preview += (
                f"**{i}. Source:** {r['meta'].get('source_file')} "
                f"(page {r['meta'].get('page')}) \nScore: {r['score']:.3f}\n\n"
            )
            txt = r["text"]
            ctx_preview += txt[:1000] + ("..." if len(txt) > 1000 else "") + "\n\n"

        # generate answer (this is the slow step)
        answer = answer_question(q)
        return answer, ctx_preview

    ask_btn.click(on_ask, inputs=[question], outputs=[output, contexts])
    gr.Markdown(
        "**Notes:** If the model doesn't load in this Space (OOM), try enabling Community GPU or use a smaller model."
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))