# Hugging Face Space export residue (status header: "Spaces: Sleeping") — kept as a comment.
import os
import shutil

import chromadb
import spacy
import streamlit as st
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings
from sentence_transformers import SentenceTransformer
from transformers import pipeline

from travel import get_travel_spots
from weather import get_weather_summary
# --- Vector store setup --------------------------------------------------
# Start every run from an empty Chroma directory: any leftover state in the
# temporary path is discarded before the client is created.
CHROMA_PATH = "/tmp/chroma"
if os.path.exists(CHROMA_PATH):
    shutil.rmtree(CHROMA_PATH)
os.makedirs(CHROMA_PATH, exist_ok=True)

# Chroma's current (non-legacy) client API, persisted under CHROMA_PATH.
client = chromadb.PersistentClient(
    path=CHROMA_PATH,
    settings=Settings(),
    tenant=DEFAULT_TENANT,
    database=DEFAULT_DATABASE,
)
db = client.get_or_create_collection("disaster_news")

# --- Model setup ---------------------------------------------------------
embed_model = SentenceTransformer("all-MiniLM-L6-v2")  # query/document embeddings
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")  # answer generation
nlp = spacy.load("en_core_web_sm")  # NER for location extraction

# Keyword lists used for crude intent detection on user queries.
weather_keywords = ["weather", "forecast", "rain", "snow", "temperature", "wind", "climate", "humid", "cold", "hot"]
travel_keywords = ["visit", "travel", "tourist", "see", "go", "spots", "places", "explore", "attractions"]
# Extract location and intent
def extract_location_and_intent(query):
    """Parse a user query into ``(location, is_weather, is_travel)``.

    The location is the first GPE/LOC named entity spaCy finds (or ``None``).
    Intent flags are set when any keyword from the module-level
    ``weather_keywords`` / ``travel_keywords`` lists appears in the query.

    :param query: free-text user question.
    :returns: tuple ``(location or None, is_weather: bool, is_travel: bool)``.
    """
    doc = nlp(query)
    locations = [ent.text for ent in doc.ents if ent.label_ in ("GPE", "LOC")]
    location = locations[0] if locations else None
    # Match keywords against whole tokens (and their lemmas) instead of raw
    # substrings: the previous `word in query.lower()` test made short
    # keywords such as "go" or "see" fire on unrelated words ("good",
    # "Chicago", "seen"). Lemmas keep inflected forms matching, e.g.
    # "raining" -> "rain", "going" -> "go".
    tokens = {tok.lower_ for tok in doc} | {tok.lemma_.lower() for tok in doc}
    is_weather = any(word in tokens for word in weather_keywords)
    is_travel = any(word in tokens for word in travel_keywords)
    return location, is_weather, is_travel
# Fallback: semantic search + QA
def query_rag_system(query):
    """Answer *query* via retrieval-augmented generation.

    Embeds the query, retrieves the 5 nearest documents from the
    ``disaster_news`` collection, and asks the FLAN-T5 pipeline to answer
    from that context.

    :param query: free-text user question.
    :returns: generated answer string, or a fixed fallback message when no
        documents are retrieved or the model produces empty output.
    """
    query_emb = embed_model.encode(query).tolist()
    results = db.query(query_embeddings=[query_emb], n_results=5)
    # Chroma returns one list of documents per query embedding; guard
    # against a missing/None "documents" key and drop empty/None entries.
    retrieved_docs = (results.get("documents") or [[]])[0]
    retrieved_docs = [doc for doc in retrieved_docs if doc]
    if not retrieved_docs:
        return "No relevant disaster info found in the database."
    context = "\n".join(retrieved_docs)
    prompt = (
        f"You are a helpful assistant. Based on the context below, answer the question.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        f"Answer clearly and specifically:"
    )
    output = qa_pipeline(prompt, max_new_tokens=200)[0]["generated_text"]
    # Strip a boilerplate prefix the model sometimes emits. The original
    # code detected the prefix case-insensitively but removed it with a
    # case-sensitive replace(), so non-canonical casings slipped through;
    # remove it by length instead.
    stripped = output.strip()
    prefix = "here's what i found"
    if stripped.lower().startswith(prefix):
        output = stripped[len(prefix):].lstrip(":").strip()
    return output if output else "No relevant answer could be generated."