# NaturalDisasters / rag_engine.py
# RAG engine: retrieves disaster-news documents from a Chroma vector store,
# generates an answer with a seq2seq QA model, and appends a weather forecast.
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from chromadb import PersistentClient
from chromadb.config import Settings
from weather import get_weather_summary
import re
# Sentence embedder used for both ingestion and query-time retrieval.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# Persistent Chroma client (use a folder that survives HF Spaces runtime if possible).
client = PersistentClient(path="./chroma_db", settings=Settings(anonymized_telemetry=False))

# Name of the vector collection holding ingested disaster-news documents.
collection_name = "disaster_news"

# get_or_create_collection is atomic and replaces the racy
# list_collections() scan + create/get branching.
db = client.get_or_create_collection(name=collection_name)

# Seq2seq model used to generate the final answer from retrieved context.
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
# Simple location extractor (can be improved later)
def extract_location(text):
    """Return the first capitalized phrase following the word 'in', else None.

    The leading ``\\b`` prevents false positives from words that merely
    contain "in" (e.g. "rain California"), and the captured phrase is
    stripped so a trailing space never reaches the weather lookup.
    """
    match = re.search(r'\bin ([A-Z][a-zA-Z\s]+)', text)
    return match.group(1).strip() if match else None
# Main RAG query function
def query_rag_system(query):
    """Answer ``query`` using retrieval-augmented generation.

    Retrieves the 5 most similar documents from the disaster-news
    collection, prompts the QA model with them as context, and appends a
    weather forecast when a location can be extracted from the query or
    the generated answer. Returns a user-facing string in all cases.
    """
    # Nothing ingested yet — tell the user instead of querying an empty index.
    if db.count() == 0:
        return "⚠️ The disaster knowledge base is empty. Please ingest documents before querying."

    # Embed the query and retrieve the closest documents.
    query_emb = embed_model.encode(query).tolist()
    results = db.query(query_embeddings=[query_emb], n_results=5)
    # First (and only) query's documents; no need to copy the list.
    retrieved_docs = results['documents'][0]
    context = "\n".join(retrieved_docs)

    # Generate the answer from the retrieved context.
    prompt = f"Answer the question using the context below:\nContext:\n{context}\n\nQuestion: {query}"
    response = qa_pipeline(prompt, max_new_tokens=200)[0]['generated_text']

    # Weather forecast (if any location can be detected). The fallback no
    # longer embeds its own "\n\n" — the return f-string already adds the
    # separator, so the old version printed four consecutive newlines.
    location = extract_location(query) or extract_location(response)
    if location:
        forecast = get_weather_summary(location)
    else:
        forecast = "🌦️ No forecast available for location."
    return f"{response}\n\n{forecast}"