# vishurdx's picture
# Update app.py
# 2297b37 verified
import os
import pandas as pd
import numpy as np
import kagglehub
import google.generativeai as genai
from typing import List
from dataclasses import dataclass, field
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# --- 1. Data Classes ---
@dataclass(frozen=True)
class EventVenue:
    """Immutable venue record attached to an event.

    Both fields are optional and default to the empty string, so
    ``EventVenue()`` is a valid "no venue information" placeholder.
    """

    # Venue name; "" when not supplied.
    name: str = ""
    # Venue address; "" when not supplied.
    address: str = ""
@dataclass(frozen=True)
class Event:
    """Immutable description of a single event.

    Every text field defaults to the empty string, so an instance can be
    built from whichever subset of columns the data source provides.
    ``primary_venue`` defaults to an empty ``EventVenue``.
    """

    # Identification and free-text description.
    name: str = ""
    url: str = ""
    description: str = ""

    # Start/end are kept as the raw strings supplied by the caller.
    start_datetime: str = ""
    end_datetime: str = ""

    # Where the event happens and who runs it.
    location: str = ""
    organizer: str = ""
    contact: str = ""

    # Pricing and category, as free text.
    cost: str = ""
    event_type: str = ""

    # The class itself serves as the factory: each Event gets its own EventVenue().
    primary_venue: EventVenue = field(default_factory=EventVenue)
# --- 2. RAG & Agent Classes ---
class EventRAGPipeline:
    """Embedding-based retriever over a fixed list of events.

    On construction every event is flattened to a single text string and
    encoded with a SentenceTransformer model; queries are then ranked
    against those stored embeddings by cosine similarity.
    """

    def __init__(self, events: List[Event], embedding_model: str = 'all-MiniLM-L6-v2'):
        """Load the embedding model and pre-compute one embedding per event.

        Args:
            events: Events to index. The list is stored as-is and indexed
                positionally, so it must not be mutated afterwards.
            embedding_model: Model name passed to ``SentenceTransformer``.
        """
        self.events = events
        print("Loading Embedding Model...")
        self.model = SentenceTransformer(embedding_model)
        print("Computing Embeddings...")
        texts = [self._event_to_text(e) for e in events]
        self.event_embeddings = self.model.encode(texts)

    def _event_to_text(self, event: Event) -> str:
        """Return the text that represents ``event`` in the embedding space.

        Joins name, description, location and event type with single
        spaces, skipping empty parts. When the event has no location the
        organizer is used in its place.
        """
        loc_text = event.location if event.location else event.organizer
        text_parts = [event.name, event.description, loc_text, event.event_type]
        return ' '.join(str(part) for part in text_parts if part)

    def query_events(self, query: str, top_k: int = 5) -> List[Event]:
        """Return the events most similar to ``query``, best match first.

        Args:
            query: Free-text search string.
            top_k: Maximum number of events to return.

        Returns:
            At most ``top_k`` events ordered by descending cosine
            similarity. An empty list when there are no indexed events or
            when ``top_k`` is not positive.
        """
        # Nothing to rank, or nothing requested. Without this guard a
        # negative top_k would slice "[:top_k]" and return almost every
        # event, and an empty index would be handed to cosine_similarity.
        if not self.events or top_k <= 0:
            return []

        # encode() on a single string yields a 1-D vector; cosine_similarity
        # expects a 2-D (n_samples, n_features) array.
        query_embedding = self.model.encode(query).reshape(1, -1)
        similarities = cosine_similarity(query_embedding, self.event_embeddings)[0]

        # argsort is ascending, so reverse before taking the first top_k.
        top_indices = similarities.argsort()[::-1][:top_k]
        return [self.events[idx] for idx in top_indices]
class GenerativeAgent:
    """Answers user questions with a Gemini model, grounded in retrieved events."""

    def __init__(self, pipeline, api_key: str, model_name: str = 'gemini-2.5-flash'):
        """Configure the Gemini client and keep a handle to the retriever.

        Args:
            pipeline: Object exposing ``query_events(query)`` that returns
                events with ``name``, ``start_datetime`` and ``description``.
            api_key: Google Generative AI API key.
            model_name: Gemini model identifier. Defaults to the model the
                application has always used, so existing callers are
                unaffected.
        """
        self.pipeline = pipeline
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name)

    def generate_answer(self, user_query: str) -> str:
        """Retrieve relevant events and ask the model to answer ``user_query``.

        Args:
            user_query: The user's question, inserted verbatim into the prompt.

        Returns:
            The model's text answer, or a string starting with
            ``"AI Error: "`` if the model call (or reading its text) fails.
            Errors raised by the retrieval step are not caught here.
        """
        relevant_events = self.pipeline.query_events(user_query)

        # One line of context per retrieved event.
        context_str = "\n".join(
            f"Name: {e.name}, Date: {e.start_datetime}, Desc: {e.description}"
            for e in relevant_events
        )

        # Same text as the original triple-quoted prompt, assembled from
        # explicit pieces so that source indentation cannot leak into it.
        prompt = (
            "\n"
            "You are an LA Event Concierge. Answer based on this context:\n"
            f"{context_str}\n"
            f"User: {user_query}\n"
        )

        try:
            return self.model.generate_content(prompt).text
        except Exception as e:
            # Deliberate best-effort boundary: the failure is reported to
            # the caller as answer text instead of propagating.
            return f"AI Error: {str(e)}"
# --- 3. App Initialization ---
# Application object that the route decorators further down attach to.
app = FastAPI()

# CORS is left fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Module-level agent handle. Stays None until load_data() assigns a
# GenerativeAgent; the chat endpoint answers 503 while it is None.
agent = None
def load_data():
    """Download the events dataset and build the module-level ``agent``.

    Reads ``GOOGLE_API_KEY`` from the environment. When it is missing or
    empty, a warning is printed and the function returns without touching
    ``agent``. Otherwise the Kaggle dataset is downloaded, its first CSV
    file (in sorted name order) is loaded, each row becomes an ``Event``,
    and ``agent`` is set to a ``GenerativeAgent`` over those events.

    Raises:
        FileNotFoundError: The downloaded dataset directory contains no
            ``.csv`` file.
    """
    global agent

    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        print("WARNING: GOOGLE_API_KEY not set.")
        return

    print("Downloading dataset...")
    path = kagglehub.dataset_download("cityofLA/los-angeles-festival-guide-2014-events")

    # os.listdir returns entries in arbitrary order; sort so the file that
    # gets loaded is the same on every run.
    csv_files = sorted(f for f in os.listdir(path) if f.endswith('.csv'))
    if not csv_files:
        # Previously this surfaced as a bare IndexError on files[0].
        raise FileNotFoundError(
            f"No CSV file found in downloaded dataset directory: {path}"
        )

    # Missing cells become "" so every Event field is a plain string.
    df = pd.read_csv(os.path.join(path, csv_files[0])).fillna("")

    events = [
        Event(
            name=str(row['event_name']),
            description=str(row['event_description']),
            start_datetime=str(row['start_date_and_time']),
            location=str(row['event_location']),
            event_type=str(row['event_type']),
        )
        for _, row in df.iterrows()
    ]

    pipeline = EventRAGPipeline(events)
    agent = GenerativeAgent(pipeline, api_key)
    print("System Ready.")
# Runs once at import time. On success the module-level `agent` is set;
# if GOOGLE_API_KEY is unset, load_data() returns early and `agent` stays None.
load_data()
# --- 4. API Endpoints ---
# JSON body accepted by POST /api/chat.
class ChatRequest(BaseModel):
    # The user's chat message; the endpoint passes it unchanged to the agent.
    message: str
@app.post("/api/chat")
async def chat_endpoint(request: ChatRequest):
    """Answer a chat message using the event agent.

    Returns a JSON object of the form ``{"response": <answer text>}``.

    Raises:
        HTTPException: 503 when no agent has been created.
    """
    if agent is None:
        raise HTTPException(status_code=503, detail="System initializing or API Key missing")

    # generate_answer is a synchronous call (retrieval plus a model request).
    # Awaiting it in the default executor keeps the event loop free to serve
    # other requests instead of blocking inside this coroutine.
    import asyncio

    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(None, agent.generate_answer, request.message)
    return {"response": response}
# --- 5. Serve Frontend (UPDATED) ---
# This explicitly looks for index.html in the CURRENT folder, so you don't need a 'static' folder.
@app.get("/")
async def read_index():
    """Serve the frontend page ``index.html`` located next to this module.

    Raises:
        HTTPException: 404 when ``index.html`` does not exist.
    """
    # Anchor the path to this file's directory so the page is found no
    # matter which working directory the server process was started from.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    index_path = os.path.join(module_dir, 'index.html')

    # A missing file would otherwise only fail while the response is being
    # sent; report it as a normal 404 instead.
    if not os.path.isfile(index_path):
        raise HTTPException(status_code=404, detail="index.html not found")

    return FileResponse(index_path)