Spaces:
Sleeping
Sleeping
chromadb - v2
Browse files
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import os
|
| 2 |
import shutil
|
| 3 |
import streamlit as st
|
| 4 |
import chromadb
|
|
@@ -9,15 +9,19 @@ from weather import get_weather_summary
|
|
| 9 |
from travel import get_travel_spots
|
| 10 |
import spacy
|
| 11 |
|
| 12 |
-
#
|
|
|
|
|
|
|
|
|
|
| 13 |
if os.path.exists("/tmp/chroma"):
|
| 14 |
shutil.rmtree("/tmp/chroma")
|
| 15 |
|
| 16 |
-
# Initialize
|
| 17 |
-
|
| 18 |
chroma_db_impl="duckdb+parquet",
|
| 19 |
persist_directory="/tmp/chroma"
|
| 20 |
-
)
|
|
|
|
| 21 |
db = client.get_or_create_collection("disaster_news")
|
| 22 |
|
| 23 |
# Load models
|
|
@@ -25,11 +29,11 @@ embed_model = SentenceTransformer("all-MiniLM-L6-v2")
|
|
| 25 |
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
|
| 26 |
nlp = spacy.load("en_core_web_sm")
|
| 27 |
|
| 28 |
-
#
|
| 29 |
weather_keywords = ["weather", "forecast", "rain", "snow", "temperature", "wind", "climate", "humid", "cold", "hot"]
|
| 30 |
travel_keywords = ["visit", "travel", "tourist", "see", "go", "spots", "places", "explore", "attractions"]
|
| 31 |
|
| 32 |
-
# Extract location and intent
|
| 33 |
def extract_location_and_intent(query):
|
| 34 |
doc = nlp(query)
|
| 35 |
locations = [ent.text for ent in doc.ents if ent.label_ in ("GPE", "LOC")]
|
|
@@ -38,7 +42,7 @@ def extract_location_and_intent(query):
|
|
| 38 |
is_travel = any(word in query.lower() for word in travel_keywords)
|
| 39 |
return location, is_weather, is_travel
|
| 40 |
|
| 41 |
-
# Fallback
|
| 42 |
def query_rag_system(query):
|
| 43 |
query_emb = embed_model.encode(query).tolist()
|
| 44 |
results = db.query(query_embeddings=[query_emb], n_results=5)
|
|
@@ -55,20 +59,16 @@ user_input = st.text_input("Ask anything (e.g. 'Rain forecast in Pune', 'Places
|
|
| 55 |
|
| 56 |
if st.button("Submit") and user_input:
|
| 57 |
location, is_weather, is_travel = extract_location_and_intent(user_input)
|
| 58 |
-
|
| 59 |
response_parts = []
|
| 60 |
|
| 61 |
-
# Weather
|
| 62 |
if is_weather and location:
|
| 63 |
weather_summary = get_weather_summary(location)
|
| 64 |
response_parts.append(weather_summary)
|
| 65 |
|
| 66 |
-
# Travel
|
| 67 |
if is_travel and location:
|
| 68 |
travel_suggestions = get_travel_spots(location)
|
| 69 |
response_parts.append(travel_suggestions)
|
| 70 |
|
| 71 |
-
# RAG fallback
|
| 72 |
if not is_weather and not is_travel:
|
| 73 |
rag_response = query_rag_system(user_input)
|
| 74 |
response_parts.append(rag_response)
|
|
@@ -79,3 +79,11 @@ if st.button("Submit") and user_input:
|
|
| 79 |
st.markdown("### 🔎 Response:")
|
| 80 |
for part in response_parts:
|
| 81 |
st.write(part)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from chromadb.config import Settingsimport os
|
| 2 |
import shutil
|
| 3 |
import streamlit as st
|
| 4 |
import chromadb
|
|
|
|
| 9 |
from travel import get_travel_spots
|
| 10 |
import spacy
|
| 11 |
|
| 12 |
+
# Force Chroma to use a valid persist directory
|
| 13 |
+
os.environ["PERSIST_DIRECTORY"] = "/tmp/chroma"
|
| 14 |
+
|
| 15 |
+
# Clean up any stale DB state (optional but avoids schema errors)
|
| 16 |
if os.path.exists("/tmp/chroma"):
|
| 17 |
shutil.rmtree("/tmp/chroma")
|
| 18 |
|
| 19 |
+
# Initialize Chroma with valid settings
|
| 20 |
+
settings = Settings(
|
| 21 |
chroma_db_impl="duckdb+parquet",
|
| 22 |
persist_directory="/tmp/chroma"
|
| 23 |
+
)
|
| 24 |
+
client = chromadb.Client(settings)
|
| 25 |
db = client.get_or_create_collection("disaster_news")
|
| 26 |
|
| 27 |
# Load models
|
|
|
|
| 29 |
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
|
| 30 |
nlp = spacy.load("en_core_web_sm")
|
| 31 |
|
| 32 |
+
# Define keyword lists
|
| 33 |
weather_keywords = ["weather", "forecast", "rain", "snow", "temperature", "wind", "climate", "humid", "cold", "hot"]
|
| 34 |
travel_keywords = ["visit", "travel", "tourist", "see", "go", "spots", "places", "explore", "attractions"]
|
| 35 |
|
| 36 |
+
# Extract location and user intent
|
| 37 |
def extract_location_and_intent(query):
|
| 38 |
doc = nlp(query)
|
| 39 |
locations = [ent.text for ent in doc.ents if ent.label_ in ("GPE", "LOC")]
|
|
|
|
| 42 |
is_travel = any(word in query.lower() for word in travel_keywords)
|
| 43 |
return location, is_weather, is_travel
|
| 44 |
|
| 45 |
+
# Fallback: semantic search + QA
|
| 46 |
def query_rag_system(query):
|
| 47 |
query_emb = embed_model.encode(query).tolist()
|
| 48 |
results = db.query(query_embeddings=[query_emb], n_results=5)
|
|
|
|
| 59 |
|
| 60 |
if st.button("Submit") and user_input:
|
| 61 |
location, is_weather, is_travel = extract_location_and_intent(user_input)
|
|
|
|
| 62 |
response_parts = []
|
| 63 |
|
|
|
|
| 64 |
if is_weather and location:
|
| 65 |
weather_summary = get_weather_summary(location)
|
| 66 |
response_parts.append(weather_summary)
|
| 67 |
|
|
|
|
| 68 |
if is_travel and location:
|
| 69 |
travel_suggestions = get_travel_spots(location)
|
| 70 |
response_parts.append(travel_suggestions)
|
| 71 |
|
|
|
|
| 72 |
if not is_weather and not is_travel:
|
| 73 |
rag_response = query_rag_system(user_input)
|
| 74 |
response_parts.append(rag_response)
|
|
|
|
| 79 |
st.markdown("### 🔎 Response:")
|
| 80 |
for part in response_parts:
|
| 81 |
st.write(part)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
settings = Settings(
|
| 85 |
+
chroma_db_impl="duckdb+parquet",
|
| 86 |
+
persist_directory="/tmp/chroma" # must not be None
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
client = chromadb.Client(settings)
|