VisLanRAG / src /rag_engine.py
zach9111's picture
Update src/rag_engine.py
7821233 verified
from huggingface_hub import HfApi
import os
# Optional Hugging Face Hub authentication: a client named `api` is created
# only when the HF_TOKEN environment variable is set to a non-empty value.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    # The token is handed to the client directly rather than saved anywhere.
    api = HfApi(token=hf_token)
import faiss
import pickle
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel, pipeline
# Module-level models, loaded once at import time and run on the CPU.
# CLIP (processor + model) encodes text queries in embed_text_with_clip;
# the FLAN-T5-small pipeline generates answers and document summaries.
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_model.eval()  # inference mode (disables dropout etc.)
clip_model.to("cpu")
qa_pipeline = pipeline(
    task="text2text-generation",
    model="google/flan-t5-small",
    tokenizer="google/flan-t5-small",
    device=-1,  # -1 selects the CPU for transformers pipelines
)
def embed_text_with_clip(text):
    """Embed *text* with the CLIP text encoder and L2-normalise the result.

    Args:
        text: Query string to embed.

    Returns:
        A 1-D float32 numpy vector, unit length unless the raw embedding is
        all zeros. search_similar_pages uses it to query the FAISS index.
    """
    # CLIP's text encoder only has position embeddings for a fixed number of
    # tokens (the tokenizer's model_max_length). Truncate long queries
    # instead of failing inside the model.
    inputs = clip_processor(
        text=[text], return_tensors="pt", padding=True, truncation=True
    )
    with torch.no_grad():
        vec = clip_model.get_text_features(**inputs)[0].numpy()
    norm = np.linalg.norm(vec)
    # Guard against a zero vector, which would otherwise normalise to NaNs.
    if norm > 0:
        vec = vec / norm
    return vec.astype("float32")
def search_similar_pages(question, top_k=5):
    """Return metadata for the pages most similar to *question*.

    Loads the FAISS index (``vision_index.faiss``) and page metadata
    (``metadata.pkl``) from the current working directory. It then embeds
    the question with CLIP and returns up to ``top_k`` metadata entries in
    the order FAISS ranked them.

    Args:
        question: Natural-language query.
        top_k: Maximum number of pages to return. Values <= 0 yield [].

    Returns:
        List of metadata entries as stored in ``metadata.pkl``. The list may
        be shorter than ``top_k`` when the index holds fewer vectors.

    Raises:
        FileNotFoundError: If the index or metadata file is missing.
    """
    if not os.path.exists("vision_index.faiss") or not os.path.exists("metadata.pkl"):
        raise FileNotFoundError("PDF not processed yet.")
    if top_k <= 0:
        return []
    index = faiss.read_index("vision_index.faiss")
    # NOTE: unpickling can execute arbitrary code; only load a metadata.pkl
    # that this application produced itself.
    with open("metadata.pkl", "rb") as f:
        metadata = pickle.load(f)
    query_vec = embed_text_with_clip(question)
    _, indices = index.search(np.array([query_vec]), top_k)
    # FAISS pads missing results with -1 when the index holds fewer than
    # top_k vectors. Left unfiltered, -1 would silently select metadata[-1]
    # (the last page).
    return [metadata[i] for i in indices[0] if 0 <= i < len(metadata)]
def ask_local_model(context, question):
    """Answer *question* from *context* using the module's qa_pipeline.

    The prompt tells the model to rely only on the supplied text and to
    reply 'Not found in text.' when unsure. Decoding is greedy
    (do_sample=False) and capped at 128 new tokens.

    Returns:
        The generated answer text of the first pipeline output.
    """
    prompt = (
        "Based only on the following text:\n\n"
        f"{context}\n\n"
        "Answer this question:\n"
        f"{question}\n\n"
        "Only answer from the text. If unsure, say 'Not found in text.'"
    )
    generations = qa_pipeline(prompt, max_new_tokens=128, do_sample=False)
    best = generations[0]
    return best["generated_text"]
def generate_answer(question):
    """Answer *question* once for each distinct page among the retrieved hits.

    Pages come from search_similar_pages. A page number that was already
    answered is skipped, and every remaining page's OCR text is passed to
    ask_local_model together with the question.

    Returns:
        List of dicts with keys "page", "thumbnail" and "answer", in
        retrieval order.
    """
    hits = search_similar_pages(question)
    answered_pages = set()
    results = []
    for hit in hits:
        page_no = hit["page"]
        if page_no not in answered_pages:
            answered_pages.add(page_no)
            # Run the model before reading the thumbnail to keep the
            # original evaluation order.
            reply = ask_local_model(hit["ocr"], question)
            results.append(
                {"page": page_no, "thumbnail": hit["thumbnail"], "answer": reply}
            )
    return results
def summarize_document(max_chars=3000):
    """Produce a short (roughly 70-80 word) description of the processed PDF.

    Joins the OCR text of every page in ``metadata.pkl`` (current working
    directory) and keeps the first ``max_chars`` characters. It then asks
    qa_pipeline to describe the document.

    Args:
        max_chars: Number of leading OCR characters included in the prompt.
            Defaults to 3000, the previous hard-coded limit.

    Returns:
        The generated description, or a user-facing message starting with
        "❗" when the document is unprocessed or contains no text.
    """
    if not os.path.exists("metadata.pkl"):
        return "❗ Document not processed yet. Please upload and process a PDF."
    # NOTE: unpickling can execute arbitrary code; only load a metadata.pkl
    # that this application produced itself.
    with open("metadata.pkl", "rb") as f:
        metadata = pickle.load(f)
    # `or ""` tolerates pages whose OCR result is None instead of a string.
    full_text = " ".join(page["ocr"] or "" for page in metadata).strip()
    if not full_text:
        # Nothing to describe; avoid prompting the model with no content.
        return "❗ No text could be extracted from the document to summarize."
    # Cap the prompt size to stay within the model's limited input length.
    short_text = full_text[:max_chars]
    prompt = f"describe what is the document about and what it contains in 70-80 words:\n\n{short_text}"
    result = qa_pipeline(prompt, max_new_tokens=200, do_sample=False)
    return result[0]["generated_text"]