"""Streamlit RAG assistant over a Chroma "disaster_news" collection.

Rebuilds a temporary Chroma store on every start, loads embedding/QA/NER
models at import time, and exposes two helpers: intent+location extraction
for routing to weather/travel features, and a semantic-search QA fallback.
"""

import os
import re
import shutil

import chromadb
import spacy
import streamlit as st
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings
from sentence_transformers import SentenceTransformer
from transformers import pipeline

from travel import get_travel_spots
from weather import get_weather_summary

# Ensure a fresh Chroma DB directory (temporary path) — wiped on every start,
# so the collection is rebuilt from scratch each run.
CHROMA_PATH = "/tmp/chroma"
if os.path.exists(CHROMA_PATH):
    shutil.rmtree(CHROMA_PATH)
os.makedirs(CHROMA_PATH, exist_ok=True)

# Initialize Chroma using the new PersistentClient API.
client = chromadb.PersistentClient(
    path=CHROMA_PATH,
    settings=Settings(),
    tenant=DEFAULT_TENANT,
    database=DEFAULT_DATABASE,
)
db = client.get_or_create_collection("disaster_news")

# Load models once at module import (expensive; reused by every request).
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
nlp = spacy.load("en_core_web_sm")

# Keyword lists for naive intent detection.
weather_keywords = ["weather", "forecast", "rain", "snow", "temperature",
                    "wind", "climate", "humid", "cold", "hot"]
travel_keywords = ["visit", "travel", "tourist", "see", "go", "spots",
                   "places", "explore", "attractions"]


def extract_location_and_intent(query):
    """Return ``(location, is_weather, is_travel)`` for a user query.

    ``location`` is the first GPE/LOC entity spaCy finds, or ``None``.
    The intent flags are whole-word keyword matches.

    Args:
        query: Raw user query string.

    Returns:
        Tuple of (location or None, is_weather bool, is_travel bool).
    """
    doc = nlp(query)
    locations = [ent.text for ent in doc.ents if ent.label_ in ("GPE", "LOC")]
    location = locations[0] if locations else None
    # Whole-word matching: a plain substring test would fire on e.g.
    # "hot" in "hotel" or "go" in "good"/"chicago".
    tokens = set(re.findall(r"[a-z]+", query.lower()))
    is_weather = any(word in tokens for word in weather_keywords)
    is_travel = any(word in tokens for word in travel_keywords)
    return location, is_weather, is_travel


def query_rag_system(query):
    """Fallback path: semantic search over Chroma + FLAN-T5 answer generation.

    Embeds the query, retrieves the top-5 matching documents, and asks the
    text2text model to answer the question from that retrieved context.

    Args:
        query: Raw user query string.

    Returns:
        The generated answer, or a fixed message when no documents match
        or no answer could be generated.
    """
    query_emb = embed_model.encode(query).tolist()
    results = db.query(query_embeddings=[query_emb], n_results=5)
    retrieved_docs = results.get("documents", [[]])[0]
    if not retrieved_docs:
        return "No relevant disaster info found in the database."
    context = "\n".join(retrieved_docs)
    # NOTE(review): the original source had a literal (unescaped) line break
    # inside this f-string — a syntax error; reconstructed with an escaped \n.
    prompt = (
        "You are a helpful assistant. \n"
        "Based on the context below, answer the question.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer clearly and specifically:"
    )
    output = qa_pipeline(prompt, max_new_tokens=200)[0]["generated_text"]
    # Strip a boilerplate preamble the model sometimes emits.
    if output.strip().lower().startswith("here's what i found"):
        output = output.replace("Here's what I found:", "").strip()
    return output if output else "No relevant answer could be generated."