# llm-chat-assistant / query.py
# Author: Vineetha00 - commit 3250fb4 (verified), "Update query.py"
import faiss
import pickle
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
from process_dataset import process_medical_dataset # or your pdf ingestion
# On-disk FAISS index: the vector index plus a pickle of the text chunks.
# get_top_k_chunks() maps search hits to chunks by position, so both files
# must come from the same build.
INDEX_DIR = "faiss_index"
INDEX_PKL_PATH = os.path.join(INDEX_DIR, "index.pkl")
INDEX_FAISS_PATH = os.path.join(INDEX_DIR, "index.faiss")

# Ensure the FAISS index exists. Both files are loaded below, so rebuild when
# either one is missing (checking only index.pkl let faiss.read_index crash
# whenever index.faiss was absent).
if not (os.path.exists(INDEX_PKL_PATH) and os.path.exists(INDEX_FAISS_PATH)):
    print("No FAISS index found. Generating now...")
    process_medical_dataset()  # Or call process_uploaded_pdf("data/uploaded.pdf")
    # Fail with a clear message if the builder did not write both files.
    _missing = [p for p in (INDEX_PKL_PATH, INDEX_FAISS_PATH) if not os.path.exists(p)]
    if _missing:
        raise FileNotFoundError(
            "FAISS index generation did not create: " + ", ".join(_missing)
        )

# Load the chunks and the FAISS index.
# NOTE(security): pickle.load can execute arbitrary code embedded in the
# file. Only load an index.pkl that this project built itself; never accept
# one from an untrusted source.
with open(INDEX_PKL_PATH, "rb") as f:
    chunks = pickle.load(f)
index = faiss.read_index(INDEX_FAISS_PATH)
# Load the FLAN-T5 tokenizer and seq2seq model by Hugging Face model id.
# Nothing here points at a local directory: from_pretrained resolves the id
# through the Hugging Face Hub and its local cache.
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Run generation on the GPU when CUDA is available, otherwise on the CPU.
# get_answer() sends its tokenized inputs to this same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

# Sentence-embedding model used to encode queries. Query vectors are only
# comparable with the FAISS index if this is the same model that embedded the
# chunks - presumably the one process_dataset uses; keep the two in sync.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
def get_top_k_chunks(question, k=3):
    """Retrieve the chunks most similar to *question* from the FAISS index.

    The question is embedded with the module-level ``embedding_model``,
    searched against ``index``, and each returned row id is mapped back to
    the entry at the same position in ``chunks``.

    Args:
        question: Query text.
        k: Maximum number of chunks to return. Values below 1 yield ``[]``.

    Returns:
        A list of at most ``k`` chunks in FAISS's ranking order (best match
        first). Fewer than ``k`` are returned when the index holds fewer
        vectors or an id falls outside ``chunks``.
    """
    if k < 1:
        return []
    question_embedding = embedding_model.encode([question])
    _distances, indices = index.search(question_embedding, k)
    # FAISS pads the result with -1 when it finds fewer than k neighbours
    # (e.g. the index holds fewer than k vectors). A bare `i < len(chunks)`
    # check let -1 through, and chunks[-1] silently returned the last chunk.
    return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
def build_prompt(question, top_chunks):
    """Assemble the question-answering prompt fed to the seq2seq model.

    Args:
        question: The user's question.
        top_chunks: Retrieved context passages (strings); they are joined
            one per line to form the context section.

    Returns:
        One prompt string: an instruction line, the context section, the
        question, and a trailing "Answer:" cue for the model to continue.
    """
    context_block = "\n".join(top_chunks)
    prompt_lines = (
        "Answer the question based on the context below:",
        "",
        "Context:",
        context_block,
        "",
        f"Question: {question}",
        "Answer:",
    )
    return "\n".join(prompt_lines)
def get_answer(prompt, max_new_tokens=150):
    """Generate an answer for *prompt* with the module-level seq2seq model.

    Args:
        prompt: Full prompt text, e.g. the output of build_prompt().
        max_new_tokens: Upper bound on the number of tokens generated for
            the answer. Defaults to 150, the previously hard-coded value.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens stripped.
    """
    # truncation=True clips prompts longer than the tokenizer's maximum input
    # length instead of failing.
    # NOTE(review): tokenizers truncate on the right by default, so an
    # oversized context cuts off the *end* of the prompt - the question and
    # the "Answer:" cue. Consider bounding the context before building the
    # prompt.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)