# Streamlit RAG app: disaster-news retrieval with weather/travel intent routing.
import os
import shutil
import streamlit as st
import chromadb
from chromadb.config import Settings, DEFAULT_TENANT, DEFAULT_DATABASE
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from weather import get_weather_summary
from travel import get_travel_spots
import spacy
# Rebuild the on-disk Chroma store from scratch on every run (temp location),
# so stale vectors from a previous session never leak into this one.
CHROMA_PATH = "/tmp/chroma"
if os.path.exists(CHROMA_PATH):
    shutil.rmtree(CHROMA_PATH)
os.makedirs(CHROMA_PATH, exist_ok=True)

# New-style persistent Chroma client backed by the freshly created directory.
client = chromadb.PersistentClient(
    path=CHROMA_PATH,
    settings=Settings(),
    tenant=DEFAULT_TENANT,
    database=DEFAULT_DATABASE,
)
# Collection holding the embedded disaster-news documents queried below.
db = client.get_or_create_collection("disaster_news")
# Module-level model singletons, loaded once at startup:
#   nlp         - spaCy pipeline used for NER (location extraction)
#   embed_model - sentence embedder feeding the Chroma similarity search
#   qa_pipeline - FLAN-T5 text2text model that writes the final answer
nlp = spacy.load("en_core_web_sm")
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
# Intent-detection vocabularies: a query mentioning any of these words is
# routed to the weather or travel handler respectively.
weather_keywords = [
    "weather", "forecast", "rain", "snow", "temperature",
    "wind", "climate", "humid", "cold", "hot",
]
travel_keywords = [
    "visit", "travel", "tourist", "see", "go",
    "spots", "places", "explore", "attractions",
]
# Extract location and intent
def extract_location_and_intent(query):
    """Parse a user query into ``(location, is_weather, is_travel)``.

    The location is the text of the first GPE/LOC entity spaCy finds, or
    ``None`` when the query names no place. Intent flags are set when any
    keyword from the module-level vocabularies appears in the query.

    Fix vs. the original: keywords were tested as raw substrings of the
    lowercased query, so short entries like "go" fired on "good" or
    "Chicago" and "hot" fired on "shot". Matching is now done against whole
    spaCy tokens and their lemmas — lemmas keep the useful inflection
    tolerance ("raining" -> "rain") that substring matching provided.
    """
    doc = nlp(query)

    # First geo-political or location entity wins; None if there is none.
    locations = [ent.text for ent in doc.ents if ent.label_ in ("GPE", "LOC")]
    location = locations[0] if locations else None

    # Case-insensitive whole-word comparison set (surface forms + lemmas),
    # built once instead of lowercasing the query per keyword.
    tokens = {tok.text.lower() for tok in doc} | {tok.lemma_.lower() for tok in doc}
    is_weather = any(word in tokens for word in weather_keywords)
    is_travel = any(word in tokens for word in travel_keywords)
    return location, is_weather, is_travel
# Fallback: semantic search + QA
def query_rag_system(query):
query_emb = embed_model.encode(query).tolist()
results = db.query(query_embeddings=[query_emb], n_results=5)
retrieved_docs = results.get("documents", [[]])[0]
if not retrieved_docs:
return "No relevant disaster info found in the database."
context = "\n".join(retrieved_docs)
prompt = (
f"You are a helpful assistant. Based on the context below, answer the question.\n\n"
f"Context:\n{context}\n\n"
f"Question: {query}\n\n"
f"Answer clearly and specifically:"
)
output = qa_pipeline(prompt, max_new_tokens=200)[0]["generated_text"]
if output.strip().lower().startswith("here's what i found"):
output = output.replace("Here's what I found:", "").strip()
return output if output else "No relevant answer could be generated."