File size: 1,997 Bytes
5d76799
 
c562093
 
5d76799
 
 
c562093
5d76799
 
c562093
 
 
 
 
 
 
 
 
 
 
 
 
5d76799
 
c562093
5d76799
 
 
 
c562093
5d76799
c562093
 
 
 
 
5d76799
 
c562093
 
5d76799
 
 
c562093
5d76799
 
 
c562093
5d76799
 
 
9664fd3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from chromadb import PersistentClient
from chromadb.config import Settings
from weather import get_weather_summary
import re

# Load the sentence-embedding model used to vectorize queries and documents.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# Initialize persistent Chroma client (use a folder that survives HF Spaces runtime if possible)
client = PersistentClient(path="./chroma_db", settings=Settings(anonymized_telemetry=False))

# Name of the vector collection holding ingested disaster news.
collection_name = "disaster_news"

# get_or_create_collection is atomic: it replaces the racy and more verbose
# list_collections() check followed by a separate create/get call.
db = client.get_or_create_collection(name=collection_name)

# Load QA model (FLAN-T5 base via the text2text-generation pipeline).
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")

# Simple location extractor (can be improved later)
def extract_location(text):
    """Return the first capitalized phrase following the word 'in', or None.

    The word boundary (\\b) prevents false matches on words that merely end
    in "in" (e.g. "rain Paris" must not yield "Paris"), and the captured
    phrase is stripped because the greedy character class can swallow
    trailing whitespace.
    """
    match = re.search(r'\bin\s+([A-Z][a-zA-Z\s]+)', text)
    return match.group(1).strip() if match else None

# Main RAG query function
def query_rag_system(query):
    """Answer *query* with retrieval-augmented generation plus a weather note.

    Embeds the query, retrieves the top-5 matching documents from the Chroma
    collection, asks FLAN-T5 to answer using that context, and appends a
    weather forecast when a location can be extracted from the query or from
    the generated answer.
    """
    # Guard: nothing to retrieve from an empty knowledge base.
    if db.count() == 0:
        return "⚠️ The disaster knowledge base is empty. Please ingest documents before querying."

    # Embed the query (Chroma expects a plain list, not a numpy array).
    query_emb = embed_model.encode(query).tolist()

    # Retrieve the 5 most similar documents for the single query embedding;
    # results['documents'] is a list-of-lists, one inner list per query.
    results = db.query(query_embeddings=[query_emb], n_results=5)
    retrieved_docs = list(results['documents'][0])
    context = "\n".join(retrieved_docs)

    # Generate the answer grounded in the retrieved context.
    prompt = f"Answer the question using the context below:\nContext:\n{context}\n\nQuestion: {query}"
    response = qa_pipeline(prompt, max_new_tokens=200)[0]['generated_text']

    # Weather forecast (if any location can be detected in query or answer).
    location = extract_location(query) or extract_location(response)
    # The fallback message previously began with "\n\n", doubling the
    # separator added in the return below; keep a single blank line either way.
    forecast = get_weather_summary(location) if location else "🌦️ No forecast available for location."

    return f"{response}\n\n{forecast}"