"""Retrieval-augmented question answering over a pre-built FAISS index.

On import this module makes sure the on-disk index exists (building it if
not), loads the chunk list and FAISS index, and initialises the seq2seq
generator and the sentence-embedding model. It then exposes three helpers:
``get_top_k_chunks``, ``build_prompt`` and ``get_answer``.
"""

import os
import pickle

import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

from process_dataset import process_medical_dataset  # or your pdf ingestion

# On-disk artefacts read below: the pickled chunk list and the FAISS index.
_CHUNKS_PATH = "faiss_index/index.pkl"
_INDEX_PATH = "faiss_index/index.faiss"

# Ensure FAISS index exists. Both files are read unconditionally below, so
# regenerate when either one is missing.
# NOTE(review): assumes process_medical_dataset() writes both files - confirm.
if not (os.path.exists(_CHUNKS_PATH) and os.path.exists(_INDEX_PATH)):
    print("No FAISS index found. Generating now...")
    process_medical_dataset()  # Or call process_uploaded_pdf("data/uploaded.pdf")

# Load FAISS index and chunks.
# SECURITY: pickle.load executes arbitrary code embedded in the file. Only
# load an index.pkl produced by this project, never one from an untrusted
# source.
with open(_CHUNKS_PATH, "rb") as f:
    chunks = pickle.load(f)

index = faiss.read_index(_INDEX_PATH)

# Load the generator model and tokenizer by name; from_pretrained resolves the
# identifier through the Hugging Face hub / local cache.
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Move model to appropriate device (GPU when CUDA is available, else CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Embedding model used to embed incoming questions.
# NOTE(review): presumably must be the same model that built the index,
# otherwise query and index vectors are not comparable - confirm against the
# ingestion code.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")


def get_top_k_chunks(question, k=3):
    """Return up to ``k`` stored chunks nearest to ``question``.

    Args:
        question: The query text; it is embedded with ``embedding_model``.
        k: Number of nearest neighbours requested from the FAISS index.

    Returns:
        A list of entries from ``chunks``, in the order FAISS returned them.
        It may hold fewer than ``k`` items when the index cannot supply
        ``k`` valid neighbours.
    """
    question_embedding = embedding_model.encode([question])
    _distances, indices = index.search(question_embedding, k)
    # FAISS pads missing results with the label -1. Without the lower bound,
    # -1 would pass the length check and silently select the *last* chunk.
    return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]


def build_prompt(question, top_chunks):
    """Build the generator prompt from retrieved context and the question.

    Args:
        question: The user question, inserted verbatim.
        top_chunks: Iterable of strings, joined with newlines to form the
            context section.

    Returns:
        The full prompt string, ending in ``"Answer:"``.
    """
    context = "\n".join(top_chunks)
    return f"Answer the question based on the context below:\n\nContext:\n{context}\n\nQuestion: {question}\nAnswer:"


def get_answer(prompt):
    """Generate an answer for ``prompt`` with the seq2seq model.

    Args:
        prompt: The full prompt text, e.g. as produced by ``build_prompt``.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    # truncation=True clips over-long prompts to the tokenizer's maximum
    # length, so very long contexts may be cut off before generation.
    inputs = tokenizer(
        prompt, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    outputs = model.generate(**inputs, max_new_tokens=150)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)