|
|
|
|
|
import os |
|
|
import json |
|
|
from pathlib import Path |
|
|
from typing import List |
|
|
|
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import faiss |
|
|
import torch |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig |
|
|
|
|
|
from huggingface_hub import login

# (duplicate `import os` removed here — os is already imported at the top of the file)

# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login(None) makes huggingface_hub fall back to an interactive
# prompt, which hangs or fails in headless environments such as Spaces.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(_hf_token)
|
|
|
|
|
|
|
|
|
|
|
# --- Model / data configuration ---
MODEL_ID = "microsoft/DialoGPT-small"  # causal LM used for answer generation
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # sentence-embedding model for retrieval
DATA_DIR = Path("data")
FAISS_INDEX_PATH = DATA_DIR / "vector_store.index"  # serialized FAISS vector index
META_PATH = DATA_DIR / "metadata.json"  # per-chunk metadata aligned with index rows

# --- Retrieval settings ---
TOP_K = 4  # number of chunks fetched per query
MAX_CONTEXT_TOKENS = 800  # tokenizer truncation limit applied to the prompt

# --- Generation settings ---
GEN_MAX_NEW_TOKENS = 128  # cap on newly generated tokens per answer
TEMPERATURE = 0.7  # sampling temperature
TOP_P = 0.9  # nucleus-sampling cutoff
|
|
|
|
|
|
|
|
|
|
|
def load_faiss_index(index_path: Path):
    """Read a serialized FAISS index from disk.

    Raises FileNotFoundError when *index_path* does not exist, so the
    caller gets a clear error instead of a faiss-level failure.
    """
    if not index_path.exists():
        raise FileNotFoundError(f"FAISS index not found at {index_path}")
    return faiss.read_index(str(index_path))
|
|
|
|
|
def load_metadata(meta_path: Path):
    """Load the chunk metadata (JSON) that accompanies the FAISS index.

    Raises FileNotFoundError when the file is absent.
    """
    if not meta_path.exists():
        raise FileNotFoundError(f"metadata.json not found at {meta_path}")
    return json.loads(meta_path.read_text(encoding="utf-8"))
|
|
|
|
|
def embed_texts(model, texts: List[str]):
    """Encode *texts* with a SentenceTransformer and L2-normalize in place.

    Normalizing the vectors makes inner-product search in FAISS equivalent
    to cosine similarity.
    """
    vectors = model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
    faiss.normalize_L2(vectors)  # mutates `vectors` in place
    return vectors
|
|
|
|
|
def retrieve_top_k(query: str, embed_model, faiss_index, metadata, top_k: int = TOP_K):
    """Embed *query* and return up to *top_k* most similar chunks.

    Each result is a dict with keys "score" (float similarity), "text"
    (full chunk text, preferring "text_full" over "text" in the metadata),
    and "meta" (the raw metadata entry).

    Bug fix: FAISS pads its result arrays with index -1 when the index
    contains fewer than top_k vectors. Indexing metadata[-1] would silently
    return the LAST chunk with a meaningless score, so padded entries are
    skipped instead.
    """
    q_emb = embed_model.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)  # cosine similarity via normalized inner product
    scores, indices = faiss_index.search(q_emb.astype('float32'), top_k)
    results = []
    for score, idx in zip(scores[0], indices[0]):
        if idx < 0:  # padding entry — no matching vector
            continue
        meta = metadata[idx]
        results.append({"score": float(score), "text": meta.get("text_full") or meta.get("text"), "meta": meta})
    return results
|
|
|
|
|
def build_prompt_from_chunks(question: str, chunks: List[dict], max_chars: int = 2000, snippet_chars: int = 400):
    """Assemble a compact RAG prompt suitable for small causal LMs.

    Parameters
    ----------
    question : the user's question, inserted verbatim.
    chunks : retrieval results, dicts with "text" and "meta" keys
        (as produced by retrieve_top_k).
    max_chars : total character budget over full chunk texts; chunk
        processing stops once a chunk would exceed it (was hard-coded 2000).
    snippet_chars : per-chunk excerpt length actually placed in the
        prompt (was hard-coded 400).

    Returns the prompt string, ending in "Answer:" so the model continues
    from there.
    """
    context_parts = []
    total_chars = 0
    for ch in chunks:
        t = ch["text"]
        # Stop once this chunk would blow the character budget. Note the
        # budget counts the FULL text length even though only a snippet is
        # included — this matches the original behavior.
        if total_chars + len(t) > max_chars:
            break
        context_parts.append(f"Source: {ch['meta'].get('source_file','unknown')} - {t[:snippet_chars]}")
        total_chars += len(t)

    context = "\n\n".join(context_parts).strip()

    prompt = f"""Context: {context}

Question: {question}
Answer:"""

    return prompt
|
|
|
|
|
|
|
|
# NOTE(review): no_grad on a loader is harmless but unnecessary — loading
# weights does not build an autograd graph.
@torch.no_grad()
def load_generation_model(model_id: str):
    """Load a causal-LM tokenizer/model pair, preferring quantization on GPU.

    Fallback order: 4-bit (bitsandbytes) on CUDA -> fp16 on CUDA ->
    fp32 on CPU -> device_map="auto" as a last resort.

    Returns:
        (tokenizer, model) tuple. Raises whatever the final fallback raises
        if every load attempt fails.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model {model_id} on device={device}")

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)

    # Many causal-LM tokenizers (GPT-2 family included) ship without a pad
    # token; reuse EOS so padding during generation does not fail.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if device == "cuda":
        try:
            # 4-bit quantization roughly quarters weight memory — needed to
            # fit larger models on small Space GPUs.
            bnb_config = BitsAndBytesConfig(load_in_4bit=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            )
            print("Model loaded with 4-bit quantization on GPU")
        except Exception as e:
            # bitsandbytes may be missing or incompatible; retry in plain fp16.
            print("4-bit load failed:", e)

            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True,
            )
            print("Model loaded with fp16 on GPU (no quantization)")
    else:
        # CPU path: bitsandbytes quantization requires CUDA, so load full fp32.
        try:
            print("Loading on CPU (quantization disabled)")
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            print("Model loaded on CPU")
        except Exception as e:
            # Last resort: let accelerate place weights automatically
            # (may spill to disk/offload). No success print on this path.
            print("CPU load failed, attempting with device_map:", e)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                trust_remote_code=True
            )
    return tokenizer, model
|
|
|
|
|
|
|
|
# --- Module-level initialization (runs once at import time) ---
print("Initializing resources...")
embed_model = SentenceTransformer(EMBED_MODEL)  # retrieval embedder
faiss_index = load_faiss_index(FAISS_INDEX_PATH)  # raises if data/ missing
metadata = load_metadata(META_PATH)  # chunk metadata aligned with index rows

# Load the generation model; on failure fall back to retrieval-only mode —
# answer_question() checks tokenizer/gen_model for None and shows raw
# contexts instead of generating.
try:
    tokenizer, gen_model = load_generation_model(MODEL_ID)
    # NOTE(review): generation_config is built but never passed to
    # generate() — answer_question passes sampling args directly. Kept for
    # backward compatibility; consider wiring it in or removing it.
    generation_config = GenerationConfig(
        max_new_tokens=GEN_MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    print("Model loaded.")
except Exception as e:
    print("Model load error:", e)
    tokenizer, gen_model = None, None
|
|
|
|
|
|
|
|
def answer_question(user_question: str):
    """Answer *user_question* via RAG: retrieve chunks, build a prompt, generate.

    Returns a plain string for the Gradio textbox. Falls back to showing the
    raw retrieved contexts when the generation model failed to load at
    startup (gen_model/tokenizer are None).
    """
    if not user_question or user_question.strip() == "":
        return "Please enter a question."

    retrieved = retrieve_top_k(user_question, embed_model, faiss_index, metadata, top_k=TOP_K)
    if not retrieved:
        return "No relevant content found in the vector store."

    prompt = build_prompt_from_chunks(user_question, retrieved)

    # Degraded mode: no LLM available — surface the retrieval results instead.
    if gen_model is None or tokenizer is None:
        preview = "Model not loaded. Here are the retrieved contexts:\n\n"
        for i, r in enumerate(retrieved, 1):
            preview += f"\n--- Result {i} (score={r['score']:.3f}) ---\nSource: {r['meta'].get('source_file')} page {r['meta'].get('page')}\n{r['text'][:1000]}\n"
        return preview

    # Truncate long prompts so they fit the model's context window.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_CONTEXT_TOKENS)

    # Move tensors to the model's device (models loaded with device_map
    # expose .device for the first shard).
    if hasattr(gen_model, 'device'):
        inputs = inputs.to(gen_model.device)

    with torch.no_grad():
        gen_ids = gen_model.generate(
            **inputs,
            max_new_tokens=GEN_MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
        )

    # Bug fix: decode only the newly generated tokens. The previous approach
    # decoded the whole sequence and stripped the prompt with
    # out.startswith(prompt), which is fragile — detokenization does not
    # always reproduce the prompt byte-for-byte, and on mismatch the entire
    # prompt leaked into the answer. Slicing by input length is exact for
    # decoder-only models, whose output always begins with the input ids.
    prompt_len = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(gen_ids[0][prompt_len:], skip_special_tokens=True).strip()

    return answer if answer else "I couldn't generate a proper response."
|
|
|
|
|
|
|
|
# Bug fix: the window title previously hard-coded "Mistral-7B" even though
# MODEL_ID is a different model — derive it from the constant instead.
with gr.Blocks(title=f"Virtual Teacher — RAG + {MODEL_ID}") as demo:
    gr.Markdown("# Virtual Teacher — RAG powered")
    with gr.Row():
        with gr.Column(scale=3):
            question = gr.Textbox(lines=3, label="Ask a question (about uploaded PDFs)")
            ask_btn = gr.Button("Ask")
            output = gr.Textbox(lines=18, label="Answer")
        with gr.Column(scale=1):
            gr.Markdown("### Retrieved Contexts (preview)")
            contexts = gr.Markdown("")
            info = gr.Markdown("Model: {}\nEmbed model: {}\nTop-K: {}".format(MODEL_ID, EMBED_MODEL, TOP_K))

    def on_ask(q):
        """Button handler: return (answer_text, markdown context preview)."""
        # Guard empty input up front so we don't run retrieval only for
        # answer_question() to reject the question anyway.
        if not q or not q.strip():
            return "Please enter a question.", ""

        # NOTE(review): retrieval runs twice per click — here for the
        # preview and again inside answer_question(). Acceptable for small
        # indexes; deduplicating would require changing answer_question's
        # signature.
        retrieved = retrieve_top_k(q, embed_model, faiss_index, metadata, top_k=TOP_K)
        ctx_preview = ""
        for i, r in enumerate(retrieved, 1):
            ctx_preview += f"**{i}. Source:** {r['meta'].get('source_file')} (page {r['meta'].get('page')}) \nScore: {r['score']:.3f}\n\n"
            txt = r['text']
            ctx_preview += txt[:1000] + ("..." if len(txt) > 1000 else "") + "\n\n"

        answer = answer_question(q)
        return answer, ctx_preview

    ask_btn.click(on_ask, inputs=[question], outputs=[output, contexts])

    gr.Markdown("**Notes:** If the model doesn't load in this Space (OOM), try enabling Community GPU or use a smaller model.")
|
|
|
|
|
if __name__ == "__main__":
    # Bind to all interfaces; the PORT env var (Spaces convention) overrides
    # the default Gradio port 7860.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))