# apollo / src/streamlit_app.py - Hugging Face Space file by guramritpal-saggu-12
# Last change: "Update src/streamlit_app.py" (commit 82eba10, verified)
"""Streamlit UI for the IPL RAG + tool calling demo."""
import json
import os
from pathlib import Path
from typing import Dict, Any, List
import numpy as np
import pandas as pd
import streamlit as st
import chromadb
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from openai import OpenAI
# Retrieval sizing: the vector DB returns FETCH_K candidates and the reranker
# keeps the best CONTEXT_K of them for the prompt.
FETCH_K = 8
CONTEXT_K = 4
# Name of the Chroma collection that holds the indexed corpus.
COLLECTION_NAME = 'ipl_rag_ui'

# On Hugging Face Spaces the data files sit next to this script, so every
# path is anchored at the script's own (resolved) directory.
SCRIPT_DIR = Path(__file__).parent.resolve()
DATA_PATH = SCRIPT_DIR / 'ipl_knowledge_base.json'
CSV_PATH = SCRIPT_DIR / 'cricket_data.csv'
VECTOR_DIR = SCRIPT_DIR / 'vector_store'
# Created at import time so the persistent Chroma client has a directory to use.
VECTOR_DIR.mkdir(parents=True, exist_ok=True)

# Sites the web-search tool may draw from.
# NOTE(review): the Hindustan Times entry carries a '/cricket' path; confirm
# the search API accepts path-qualified entries and not only bare hosts.
ALLOWED_WEB_DOMAINS = [
    'www.cricbuzz.com',
    'www.espncricinfo.com',
    'www.iplt20.com',
    'www.bcci.tv',
    'www.hindustantimes.com/cricket',
]
st.set_page_config(page_title='IPL RAG Copilot', page_icon='🏏', layout='wide')
st.title('🏏 IPL RAG Copilot')
st.caption('Chroma vector DB + multi-rerank pipeline for IPL insights')

# The key field is pre-filled with whatever OPENAI_API_KEY the process has.
api_key = st.sidebar.text_input(
    'OPENAI_API_KEY',
    value=os.environ.get('OPENAI_API_KEY', ''),
    type='password',
)
if api_key:
    # The OpenAI() clients in this file are built without an explicit key and
    # so rely on the environment; export the sidebar key for them.
    # NOTE(review): os.environ is process-wide; on a shared Streamlit server
    # this key would become the pre-filled default for later sessions too.
    os.environ['OPENAI_API_KEY'] = api_key
st.sidebar.markdown('Provide a key, then rebuild the vector store if needed.')
rerank_strategy = st.sidebar.selectbox(
    'Rerank strategy',
    ['cross_encoder', 'bm25', 'none'],
    index=0,
)
@st.cache_data(show_spinner=False)
def load_kb() -> Dict[str, Any]:
    """Read and parse the IPL knowledge-base JSON stored at ``DATA_PATH``.

    The parsed result is cached with ``st.cache_data`` so reruns reuse it.

    Raises:
        FileNotFoundError: if the JSON file does not exist.
    """
    raw_text = Path(DATA_PATH).read_text(encoding='utf-8')
    return json.loads(raw_text)
@st.cache_data(show_spinner=False)
def load_stats_df() -> pd.DataFrame:
    """Load the per-season player stats CSV at ``CSV_PATH`` with numeric columns.

    Rows whose ``Year`` is the literal ``'No stats'`` are dropped. The remaining
    ``Year`` values are then parsed strictly; a bad value raises. Every stat column
    is coerced to numeric, and unparsable entries become NaN.
    The result is cached with ``st.cache_data``.
    """
    stat_columns = [
        'Matches_Batted', 'Not_Outs', 'Runs_Scored', 'Batting_Average', 'Balls_Faced', 'Batting_Strike_Rate',
        'Centuries', 'Half_Centuries', 'Fours', 'Sixes', 'Matches_Bowled', 'Balls_Bowled', 'Runs_Conceded',
        'Wickets_Taken', 'Bowling_Average', 'Economy_Rate', 'Bowling_Strike_Rate', 'Four_Wicket_Hauls', 'Five_Wicket_Hauls'
    ]
    raw = pd.read_csv(Path(CSV_PATH))
    frame = raw.loc[raw['Year'] != 'No stats'].copy()
    frame['Year'] = pd.to_numeric(frame['Year'])
    return frame.assign(**{name: pd.to_numeric(frame[name], errors='coerce') for name in stat_columns})
def summarize_player_stats(player_df: pd.DataFrame) -> str:
    """Condense a player's four most recent seasons into a one-line text capsule.

    Seasons are taken newest-first by ``Year``. A season contributes to the
    batting summary when ``Runs_Scored`` is present (not NaN/None). It contributes
    to the bowling summary when ``Wickets_Taken`` is present.

    Args:
        player_df: rows for one player with ``Year`` plus the batting/bowling
            stat columns used below.

    Returns:
        ``"Recent batting -> ... Recent bowling -> ..."``. Seasons within each
        part are joined with ``'; '``. A placeholder is used when no season
        qualifies for that part.
    """
    def as_count(value: Any) -> int:
        # ``value or 0`` is not enough: NaN is truthy, so int(NaN) would raise.
        return 0 if pd.isna(value) else int(value)

    recent = player_df.sort_values('Year', ascending=False).head(4)
    batting_parts: List[str] = []
    bowling_parts: List[str] = []
    for _, row in recent.iterrows():
        # pd.isna also accepts None, which np.isnan rejects with TypeError.
        if not pd.isna(row.get('Runs_Scored', np.nan)):
            batting_parts.append(
                f"{int(row.Year)}: {as_count(row.Matches_Batted)} inns, "
                f"{as_count(row.Runs_Scored)} runs @ avg {row.Batting_Average:.1f} SR {row.Batting_Strike_Rate:.1f}"
            )
        if not pd.isna(row.get('Wickets_Taken', np.nan)):
            bowling_parts.append(
                f"{int(row.Year)}: {as_count(row.Matches_Bowled)} matches, "
                f"{as_count(row.Wickets_Taken)} wkts @ avg {row.Bowling_Average:.1f} eco {row.Economy_Rate:.2f}"
            )
    batting_summary = '; '.join(batting_parts) if batting_parts else 'No batting sample'
    bowling_summary = '; '.join(bowling_parts) if bowling_parts else 'No bowling sample'
    return f"Recent batting -> {batting_summary}. Recent bowling -> {bowling_summary}."
def build_corpus(kb: Dict[str, Any], stats_df: pd.DataFrame) -> List[Dict[str, Any]]:
    """Flatten the knowledge base and stats table into retrievable text documents.

    Every document is ``{'id': ..., 'type': ..., 'text': ...}``. Documents are
    produced in this order:
      1. matches
      2. player profiles
      3. team briefs (id ``team-<team>``)
      4. venue notes (id ``venue-<venue>``)
      5. the single season summary (id ``season-<season>``)
      6. one stat capsule per ``Player_Name`` group in ``stats_df``
         (id ``stats-<name>``)
    """
    def doc(doc_id: Any, doc_type: str, text: str) -> Dict[str, Any]:
        return {'id': doc_id, 'type': doc_type, 'text': text}

    corpus: List[Dict[str, Any]] = []

    corpus.extend(
        doc(
            m['match_id'],
            'match',
            f"{m['stage']} match {m['match_id']} at {m['venue']} featuring {m['teams']}. "
            f"Result: {m['result']}. Theme: {m['theme']}. Key events: {'; '.join(m['key_events'])}. "
            f"Impact player: {m['impact_player']}. Tactical notes: {m['tactical_notes']}",
        )
        for m in kb['matches']
    )

    corpus.extend(
        doc(
            p['player'],
            'player',
            f"{p['player']} ({p['team']}) roles {', '.join(p['roles'])}. "
            f"Strengths: {', '.join(p['strengths'])}. Highlights: {p['season_highlights']}. "
            f"Impact metrics: {p.get('impact_metrics')}",
        )
        for p in kb['player_profiles']
    )

    corpus.extend(
        doc(
            f"team-{t['team']}",
            'team',
            f"{t['team']} led by {t['captain']} (coach {t['coach']}) style: {t['style']}",
        )
        for t in kb['team_briefs']
    )

    corpus.extend(
        doc(
            f"venue-{v['venue']}",
            'venue',
            f"{v['venue']} avg first innings {v['average_first_innings']} runs; note: {v['spice']}",
        )
        for v in kb['stadium_notes']
    )

    season = kb['season_summary']
    corpus.append(doc(
        f"season-{season['season']}",
        'season',
        f"Season narrative: {season['narrative']}. "
        f"Trends: {'; '.join(season['trends'])}. Awards: {season['awards']}",
    ))

    corpus.extend(
        doc(f"stats-{name}", 'stat', f"Stat capsule for {name}: {summarize_player_stats(rows)}")
        for name, rows in stats_df.groupby('Player_Name')
    )
    return corpus
def embed_texts(texts: List[str], batch_size: int = 2048) -> np.ndarray:
    """Embed ``texts`` with OpenAI ``text-embedding-3-small``.

    Args:
        texts: strings to embed. Row ``i`` of the result corresponds to ``texts[i]``.
        batch_size: maximum number of inputs per embeddings request. The
            endpoint rejects input arrays longer than 2048, so larger inputs are
            split into several requests.

    Returns:
        A float32 array with one embedding row per input. An empty ``texts``
        yields an empty array without calling the API.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1')
    if not texts:
        # Avoid a round-trip that the API would reject anyway.
        return np.array([], dtype='float32')
    client = OpenAI()
    vectors: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        response = client.embeddings.create(
            model='text-embedding-3-small',
            input=texts[start:start + batch_size],
        )
        # Order by the response's per-item index so rows line up with inputs.
        vectors.extend(item.embedding for item in sorted(response.data, key=lambda item: item.index))
    return np.array(vectors, dtype='float32')
def vector_client() -> chromadb.PersistentClient:
    """Create a Chroma client whose collections are persisted under ``VECTOR_DIR``."""
    persist_path = str(VECTOR_DIR)
    return chromadb.PersistentClient(path=persist_path)
@st.cache_resource(show_spinner=False)
def get_cross_encoder() -> CrossEncoder:
    """Load the MS MARCO MiniLM-L-12 cross-encoder; ``st.cache_resource`` keeps one instance."""
    model_name = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
    return CrossEncoder(model_name)
@st.cache_resource(show_spinner=True)
def init_vector_store(kb_dict: Dict[str, Any], stats_payload: str):
    """Rebuild the Chroma collection from the knowledge base and stats table.

    Any existing ``COLLECTION_NAME`` collection is dropped and recreated with
    cosine distance. Every document from ``build_corpus`` is then embedded and
    added. The result is cached with ``st.cache_resource``; clearing that cache
    forces the next call to rebuild.

    Args:
        kb_dict: parsed knowledge base.
        stats_payload: the stats table as a JSON string. The caller in this file
            produces it with ``DataFrame.to_json(orient='records')``.

    Returns:
        ``(corpus, collection)``: the indexed documents and the populated
        Chroma collection.
    """
    from io import StringIO

    # pandas >= 2.1 deprecates passing literal JSON text to read_json; wrapping
    # it in a file-like object parses the same data without the FutureWarning.
    stats_df = pd.read_json(StringIO(stats_payload))
    corpus = build_corpus(kb_dict, stats_df)

    client = vector_client()
    try:
        client.delete_collection(name=COLLECTION_NAME)
    except Exception:
        # Deliberate best effort: on a first run there is nothing to delete,
        # and the exact exception type for that is not pinned here.
        pass
    collection = client.get_or_create_collection(name=COLLECTION_NAME, metadata={'hnsw:space': 'cosine'})

    texts = [doc['text'] for doc in corpus]
    embeddings = embed_texts(texts)
    collection.add(
        ids=[doc['id'] for doc in corpus],
        documents=texts,
        metadatas=[{'type': doc['type']} for doc in corpus],
        embeddings=[vec.tolist() for vec in embeddings],
    )
    return corpus, collection
def vector_search(question: str, collection) -> List[Dict[str, Any]]:
    """Return up to ``FETCH_K`` nearest documents to ``question`` from ``collection``.

    Each hit is ``{'id', 'type', 'text', 'score'}`` with ``score = 1 - distance``.
    That is cosine similarity for a collection built in cosine space, as
    ``init_vector_store`` does. When the stored id is empty, ``id`` falls back
    to the metadata ``id`` and then to ``doc-<position>``.
    """
    query_embedding = embed_texts([question])[0].tolist()
    result = collection.query(
        query_embeddings=[query_embedding],
        n_results=FETCH_K,
        include=['documents', 'metadatas', 'distances'],
    )
    rows = zip(result['distances'][0], result['ids'][0], result['documents'][0], result['metadatas'][0])
    return [
        {
            'id': doc_id or meta.get('id') or f"doc-{position}",
            'type': meta.get('type', 'unknown'),
            'text': text,
            'score': float(1 - distance),
        }
        for position, (distance, doc_id, text, meta) in enumerate(rows)
    ]
def rerank_none(query: str, candidates: List[Dict[str, Any]]):
    """Pass-through reranker: keep the vector-search order, cut to ``CONTEXT_K``.

    ``query`` is unused; it is accepted so every reranker shares one signature.
    """
    head = candidates[:CONTEXT_K]
    return head
def rerank_bm25(query: str, candidates: List[Dict[str, Any]]):
    """Re-score ``candidates`` with BM25 fitted on their texts; keep the top ``CONTEXT_K``.

    Args:
        query: user question scored against each candidate's ``text``.
        candidates: vector-search hits, each with a ``text`` key.

    Returns:
        Shallow copies of the candidates with an added ``rerank_score``, sorted
        by descending score. Ties keep their incoming order. An empty input
        returns an empty list.
    """
    import re

    if not candidates:
        return []

    def tokenize(text: str) -> List[str]:
        # Word-character runs instead of whitespace splits, so punctuation such
        # as the '?' in "bat?" or the '.' in "Kohli." does not block a match.
        return re.findall(r'\w+', text.lower())

    tokenized = [tokenize(doc['text']) for doc in candidates]
    if not any(tokenized):
        # No terms at all: BM25Okapi's IDF averaging would fail, so keep vector order.
        return candidates[:CONTEXT_K]
    bm25 = BM25Okapi(tokenized)
    scores = bm25.get_scores(tokenize(query))
    enriched = [{**doc, 'rerank_score': float(score)} for doc, score in zip(candidates, scores)]
    enriched.sort(key=lambda d: d.get('rerank_score', 0), reverse=True)
    return enriched[:CONTEXT_K]
def rerank_cross_encoder(query: str, candidates: List[Dict[str, Any]]):
    """Score (query, text) pairs with the cached cross-encoder; keep the top ``CONTEXT_K``.

    Returns shallow copies of the candidates with an added ``rerank_score``,
    sorted by descending score. Ties keep their incoming order. An empty input
    returns an empty list without loading the model.
    """
    if not candidates:
        return []
    scores = get_cross_encoder().predict([[query, doc['text']] for doc in candidates])
    rescored = [{**doc, 'rerank_score': float(score)} for doc, score in zip(candidates, scores)]
    ranked = sorted(rescored, key=lambda d: d.get('rerank_score', 0), reverse=True)
    return ranked[:CONTEXT_K]
# Dispatch table: sidebar "Rerank strategy" choice -> reranker implementation.
RERANKERS = dict(
    cross_encoder=rerank_cross_encoder,
    bm25=rerank_bm25,
    none=rerank_none,
)
def openai_websearch(query: str, max_results: int = 4) -> Dict[str, Any]:
    """Summarise web results for ``query`` from the sites in ``ALLOWED_WEB_DOMAINS``.

    Uses the OpenAI Responses API with the built-in ``web_search`` tool.

    Args:
        query: search question forwarded to the model.
        max_results: requested upper bound on the number of results to use.
            The tool has no result-count option, so this is stated in the
            instructions.

    Returns:
        ``{'query': query, 'summary': text}``. ``text`` is the model's output
        text, or a dump of the raw response when no output text is available.
    """
    # Domain filters take host names, so drop any path suffix (e.g. '/cricket').
    # Order is kept and duplicates are removed.
    domains = list(dict.fromkeys(entry.split('/', 1)[0] for entry in ALLOWED_WEB_DOMAINS))
    client = OpenAI()
    # responses.create has no ``web_search=`` keyword (passing one raises
    # TypeError); the domain restriction belongs on the tool's ``filters``.
    response = client.responses.create(
        model='gpt-4.1-mini',
        instructions=(
            'Return concise bullet summaries citing sources. '
            f'Use at most {max_results} results.'
        ),
        input=query,
        tools=[{'type': 'web_search', 'filters': {'allowed_domains': domains}}],
        temperature=0,
    )
    summary = getattr(response, 'output_text', None)
    if not summary:
        try:
            summary = response.model_dump_json(indent=2)
        except Exception:
            summary = str(response)
    return {'query': query, 'summary': summary}
def get_match_facts(kb: Dict[str, Any], match_id: str) -> Dict[str, Any]:
    """Look up a match by id (case-insensitive) and return its key facts.

    Returns a flat dict of stage, venue, result, theme, top performers, impact
    player and tactical notes. Returns ``{'error': 'Unknown match <id>'}`` when
    no match carries that id.
    """
    for match in kb['matches']:
        if match['match_id'].lower() == match_id.lower():
            break
    else:
        return {'error': f'Unknown match {match_id}'}

    facts = {key: match[key] for key in ('match_id', 'stage', 'venue', 'result', 'theme')}
    facts['top_batters'] = match['top_performers']['batting']
    facts['top_bowlers'] = match['top_performers']['bowling']
    facts['impact_player'] = match['impact_player']
    facts['tactical_notes'] = match['tactical_notes']
    return facts
def get_player_profile(kb: Dict[str, Any], player_name: str) -> Dict[str, Any]:
    """Return the first profile whose ``player`` equals ``player_name`` case-insensitively.

    The stored profile dict itself is returned, not a copy. When nothing
    matches, ``{'error': 'Unknown player <name>'}`` is returned.
    """
    for candidate in kb['player_profiles']:
        if candidate['player'].lower() == player_name.lower():
            return candidate
    return {'error': f'Unknown player {player_name}'}
def get_player_stats(stats_df: pd.DataFrame, player_name: str, recent_rows: int = 5) -> Dict[str, Any]:
    """Return a player's most recent seasons from the stats table.

    Args:
        stats_df: table with ``Player_Name``, ``Year`` and the stat columns below.
        player_name: matched case-insensitively against ``Player_Name``.
        recent_rows: number of seasons to include, newest first. Values below 1
            are treated as 1. Without this, a model-supplied 0 or negative number
            would hit pandas' "all but the last n rows" ``head`` semantics.

    Returns:
        ``{'player': player_name, 'seasons': [...]}`` with one dict per season.
        Missing stats are ``None`` rather than NaN, so the payload serialises as
        strict JSON. Returns ``{'error': ...}`` when the player has no rows.
    """
    name_mask = stats_df['Player_Name'].str.lower() == player_name.lower()
    player_rows = stats_df[name_mask].sort_values('Year', ascending=False)
    if player_rows.empty:
        return {'error': f'No stats found for {player_name}'}
    cols = ['Year', 'Matches_Batted', 'Runs_Scored', 'Batting_Average', 'Batting_Strike_Rate',
            'Matches_Bowled', 'Wickets_Taken', 'Bowling_Average', 'Economy_Rate']
    records = player_rows.head(max(1, int(recent_rows)))[cols].to_dict(orient='records')
    # json.dumps would emit NaN as the non-standard token ``NaN``; use null instead.
    seasons = [
        {key: (None if pd.isna(value) else value) for key, value in record.items()}
        for record in records
    ]
    return {'player': player_name, 'seasons': seasons}
def _agent_tool_specs() -> List[Dict[str, Any]]:
    """Build the function-calling schemas for the tools ``run_agent`` offers the model."""
    return [
        {
            'type': 'function',
            'function': {
                'name': 'get_match_facts',
                'description': 'Return structured stats for a match id such as IPL2024-42',
                'parameters': {'type': 'object', 'properties': {'match_id': {'type': 'string'}}, 'required': ['match_id']}
            }
        },
        {
            'type': 'function',
            'function': {
                'name': 'get_player_profile',
                'description': 'Return roles and highlights for a player',
                'parameters': {'type': 'object', 'properties': {'player_name': {'type': 'string'}}, 'required': ['player_name']}
            }
        },
        {
            'type': 'function',
            'function': {
                'name': 'get_player_stats',
                'description': 'Return recent CSV stats for a player',
                'parameters': {'type': 'object', 'properties': {'player_name': {'type': 'string'}, 'recent_rows': {'type': 'integer', 'default': 5}}, 'required': ['player_name']}
            }
        },
        {
            'type': 'function',
            'function': {
                'name': 'openai_websearch',
                'description': 'Use OpenAI web search limited to Cricbuzz/ESPNCricinfo/IPLT20/BCCI/HT.',
                'parameters': {'type': 'object', 'properties': {'query': {'type': 'string'}, 'max_results': {'type': 'integer', 'default': 4}}, 'required': ['query']}
            }
        }
    ]


def _run_tool_call(call: Any, available_tools: Dict[str, Any]) -> Dict[str, Any]:
    """Execute one model-requested tool call, returning an ``{'error': ...}`` payload on failure."""
    name = call.function.name
    func = available_tools.get(name)
    if func is None:
        return {'error': f'Unknown tool {name}'}
    try:
        params = json.loads(call.function.arguments or '{}')
    except json.JSONDecodeError as exc:
        return {'error': f'Invalid JSON arguments for {name}: {exc}'}
    if not isinstance(params, dict):
        return {'error': f'Arguments for {name} must be a JSON object'}
    try:
        return func(**params)
    except Exception as exc:
        # Tool boundary: report the failure to the model so it can still answer,
        # instead of aborting the whole request.
        return {'error': f'{name} failed: {exc}'}


def run_agent(question: str, kb: Dict[str, Any], stats_df: pd.DataFrame, collection, rerank_key: str):
    """Answer ``question`` with retrieval-augmented generation plus one optional tool round.

    Steps:
      1. Vector-search ``collection`` for the question.
      2. Rerank the hits with ``RERANKERS[rerank_key]``; unknown keys fall back
         to the cross-encoder.
      3. Prompt ``gpt-4.1-mini`` with the top contexts and the tool schemas.
      4. If the model requests tools, run each call and ask once more for the
         final answer. Failed tool calls come back to the model as ``{'error': ...}``.

    Returns:
        ``(answer_text, contexts)``: the model's reply content and the reranked
        context documents that were placed in the prompt.
    """
    base_hits = vector_search(question, collection)
    reranker = RERANKERS.get(rerank_key, rerank_cross_encoder)
    contexts = reranker(question, base_hits)
    context_block = '\n\n'.join([f"[{c.get('type','doc')}::{c.get('id','unknown')}] {c['text']}" for c in contexts])

    dream11_hint = ''
    lowered = question.lower()
    if any(token in lowered for token in ['dream11', 'dream 11', 'playing xi', 'playing 11', 'best 11']):
        dream11_hint = '\n\nFocus on returning a balanced fantasy XI (2 WKs max, at least 3 bowlers) with cited rationale.'
    system_prompt = 'You are an IPL analyst who cites vector DB evidence and calls tools when needed.'
    user_prompt = f"Question: {question}\n\nContext:\n{context_block}\n\nRerank strategy: {rerank_key}{dream11_hint}"

    client = OpenAI()
    first = client.chat.completions.create(
        model='gpt-4.1-mini',
        messages=[
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': user_prompt}
        ],
        tools=_agent_tool_specs(),
        tool_choice='auto',
        temperature=0.2
    )
    message = first.choices[0].message
    if not message.tool_calls:
        return message.content, contexts

    available_tools = {
        'get_match_facts': lambda **kwargs: get_match_facts(kb, **kwargs),
        'get_player_profile': lambda **kwargs: get_player_profile(kb, **kwargs),
        'get_player_stats': lambda **kwargs: get_player_stats(stats_df, **kwargs),
        'openai_websearch': openai_websearch
    }
    tool_messages = [
        {
            'role': 'tool',
            'tool_call_id': call.id,
            'name': call.function.name,
            'content': json.dumps(_run_tool_call(call, available_tools)),
        }
        for call in message.tool_calls
    ]
    follow_up = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': user_prompt},
        message,
        *tool_messages
    ]
    second = client.chat.completions.create(
        model='gpt-4.1-mini',
        messages=follow_up,
        temperature=0.2
    )
    return second.choices[0].message.content, contexts
# Main execution
try:
    kb = load_kb()
    stats_df = load_stats_df()
    stats_payload = stats_df.to_json(orient='records')

    rebuild_requested = st.sidebar.button('Build / refresh vector store', disabled=not api_key)
    if rebuild_requested:
        # Dropping the cached resource makes init_vector_store re-embed below.
        init_vector_store.clear()

    if not api_key:
        st.warning('Provide an OpenAI API key to run the agent.')
        st.stop()

    corpus, collection = init_vector_store(kb, stats_payload)
    if rebuild_requested:
        # Report success only once the rebuild has actually completed.
        st.sidebar.success('Rebuilt vector store')

    query = st.text_area('Ask anything about IPL 2024 (matches, players, venues, tactics)', height=140)
    if st.button('Run query', disabled=not query.strip()):
        with st.spinner('Calling vector DB + RAG agent...'):
            answer, contexts = run_agent(query.strip(), kb, stats_df, collection, rerank_strategy)
        st.success('Answer')
        # The model may return no text content; avoid rendering the literal "None".
        st.write(answer or '_The model returned no text answer._')
        with st.expander('Retrieved context'):
            for ctx in contexts:
                sim = ctx.get('score', 0.0)
                rerank_score = ctx.get('rerank_score')
                suffix = f", rerank={rerank_score:.2f}" if rerank_score is not None else ''
                st.markdown(f"**{ctx.get('type','doc')}::{ctx.get('id','unknown')}** (sim={sim:.2f}{suffix})")
                st.write(ctx['text'])
                st.divider()
    else:
        st.info('Enter a query and click run to test the pipeline.')
except FileNotFoundError as e:
    st.error(f'Data file not found: {e}')
    st.info(f'Looking for files in: {SCRIPT_DIR}')
    # Name the real script file rather than a hard-coded 'app.py'.
    st.info(f'Please ensure ipl_knowledge_base.json and cricket_data.csv are in the same directory as {Path(__file__).name}')
except Exception as e:
    st.error(f'Error loading application: {e}')
    import traceback
    st.code(traceback.format_exc())