"""Streamlit UI for the IPL RAG + tool calling demo."""
import io
import json
import os
import traceback
from pathlib import Path
from typing import Dict, Any, List

import numpy as np
import pandas as pd
import streamlit as st
import chromadb
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from openai import OpenAI

# For Hugging Face Spaces: assume data files are in the same directory as app.py
SCRIPT_DIR = Path(__file__).parent.resolve()
DATA_PATH = SCRIPT_DIR / 'ipl_knowledge_base.json'
CSV_PATH = SCRIPT_DIR / 'cricket_data.csv'
VECTOR_DIR = SCRIPT_DIR / 'vector_store'
VECTOR_DIR.mkdir(parents=True, exist_ok=True)

FETCH_K = 8  # candidates pulled from the vector store per query
CONTEXT_K = 4  # candidates kept after reranking and sent to the LLM
COLLECTION_NAME = 'ipl_rag_ui'
# Texts per embeddings request, so large corpora stay under the per-request input cap.
EMBED_BATCH_SIZE = 256
ALLOWED_WEB_DOMAINS = [
    'www.cricbuzz.com',
    'www.espncricinfo.com',
    'www.iplt20.com',
    'www.bcci.tv',
    'www.hindustantimes.com/cricket'
]

st.set_page_config(page_title='IPL RAG Copilot', page_icon='🏏', layout='wide')
st.title('🏏 IPL RAG Copilot')
st.caption('Chroma vector DB + multi-rerank pipeline for IPL insights')

api_key = st.sidebar.text_input('OPENAI_API_KEY', value=os.environ.get('OPENAI_API_KEY', ''), type='password')
st.sidebar.markdown('Provide a key, then rebuild the vector store if needed.')
rerank_strategy = st.sidebar.selectbox('Rerank strategy', ['cross_encoder', 'bm25', 'none'], index=0)

if api_key:
    # OpenAI() clients created below read the key from the environment.
    os.environ['OPENAI_API_KEY'] = api_key


@st.cache_data(show_spinner=False)
def load_kb() -> Dict[str, Any]:
    """Load and return the knowledge-base JSON file at DATA_PATH."""
    with open(DATA_PATH, 'r', encoding='utf-8') as f:
        return json.load(f)


@st.cache_data(show_spinner=False)
def load_stats_df() -> pd.DataFrame:
    """Load the per-season CSV stats and coerce the numeric columns.

    Rows whose ``Year`` is ``'No stats'`` or otherwise non-numeric are
    dropped; unparseable cells in the stat columns become NaN.
    """
    df = pd.read_csv(CSV_PATH)
    df = df[df['Year'] != 'No stats'].copy()
    # Coerce instead of raising so one malformed Year cell cannot take the whole app down.
    df['Year'] = pd.to_numeric(df['Year'], errors='coerce')
    df = df.dropna(subset=['Year']).astype({'Year': int})
    numeric_cols = [
        'Matches_Batted', 'Not_Outs', 'Runs_Scored', 'Batting_Average', 'Balls_Faced',
        'Batting_Strike_Rate', 'Centuries', 'Half_Centuries', 'Fours', 'Sixes',
        'Matches_Bowled', 'Balls_Bowled', 'Runs_Conceded', 'Wickets_Taken',
        'Bowling_Average', 'Economy_Rate', 'Bowling_Strike_Rate',
        'Four_Wicket_Hauls', 'Five_Wicket_Hauls'
    ]
    for col in numeric_cols:
        df[col] = pd.to_numeric(df[col], errors='coerce')
    return df


def _as_int(value: Any) -> int:
    """Convert a possibly-missing numeric cell to ``int``, treating NaN/None as 0.

    ``int(x or 0)`` is not enough: NaN is truthy, so ``int(nan or 0)`` raises.
    """
    return 0 if pd.isna(value) else int(value)


def summarize_player_stats(player_df: pd.DataFrame) -> str:
    """Summarise a player's four most recent seasons as one text line.

    A season contributes a batting entry only when ``Runs_Scored`` is
    present, and a bowling entry only when ``Wickets_Taken`` is present.
    """
    recent = player_df.sort_values('Year', ascending=False).head(4)
    batting_parts = []
    bowling_parts = []
    for _, row in recent.iterrows():
        if not pd.isna(row.get('Runs_Scored', np.nan)):
            batting_parts.append(
                f"{_as_int(row.Year)}: {_as_int(row.Matches_Batted)} inns, "
                f"{_as_int(row.Runs_Scored)} runs @ avg {row.Batting_Average:.1f} SR {row.Batting_Strike_Rate:.1f}"
            )
        if not pd.isna(row.get('Wickets_Taken', np.nan)):
            bowling_parts.append(
                f"{_as_int(row.Year)}: {_as_int(row.Matches_Bowled)} matches, "
                f"{_as_int(row.Wickets_Taken)} wkts @ avg {row.Bowling_Average:.1f} eco {row.Economy_Rate:.2f}"
            )
    batting_summary = '; '.join(batting_parts) if batting_parts else 'No batting sample'
    bowling_summary = '; '.join(bowling_parts) if bowling_parts else 'No bowling sample'
    return f"Recent batting -> {batting_summary}. Recent bowling -> {bowling_summary}."


def build_corpus(kb: Dict[str, Any], stats_df: pd.DataFrame) -> List[Dict[str, Any]]:
    """Flatten the knowledge base and CSV stats into ``{'id','type','text'}`` documents.

    One document is produced per match, player profile, team brief, venue
    note, the season summary, and one stat capsule per ``Player_Name``.
    """
    docs: List[Dict[str, Any]] = []
    for match in kb['matches']:
        docs.append({
            'id': match['match_id'],
            'type': 'match',
            'text': (
                f"{match['stage']} match {match['match_id']} at {match['venue']} featuring {match['teams']}. "
                f"Result: {match['result']}. Theme: {match['theme']}. Key events: {'; '.join(match['key_events'])}. "
                f"Impact player: {match['impact_player']}. Tactical notes: {match['tactical_notes']}"
            )
        })
    for player in kb['player_profiles']:
        docs.append({
            'id': player['player'],
            'type': 'player',
            'text': (
                f"{player['player']} ({player['team']}) roles {', '.join(player['roles'])}. "
                f"Strengths: {', '.join(player['strengths'])}. Highlights: {player['season_highlights']}. "
                f"Impact metrics: {player.get('impact_metrics')}"
            )
        })
    for team in kb['team_briefs']:
        docs.append({
            'id': f"team-{team['team']}",
            'type': 'team',
            'text': f"{team['team']} led by {team['captain']} (coach {team['coach']}) style: {team['style']}"
        })
    for venue in kb['stadium_notes']:
        docs.append({
            'id': f"venue-{venue['venue']}",
            'type': 'venue',
            'text': f"{venue['venue']} avg first innings {venue['average_first_innings']} runs; note: {venue['spice']}"
        })
    season = kb['season_summary']
    docs.append({
        'id': f"season-{season['season']}",
        'type': 'season',
        'text': (
            f"Season narrative: {season['narrative']}. "
            f"Trends: {'; '.join(season['trends'])}. Awards: {season['awards']}"
        )
    })
    for player_name, player_df in stats_df.groupby('Player_Name'):
        docs.append({
            'id': f"stats-{player_name}",
            'type': 'stat',
            'text': f"Stat capsule for {player_name}: {summarize_player_stats(player_df)}"
        })
    return docs


def embed_texts(texts: List[str]) -> np.ndarray:
    """Embed ``texts`` with OpenAI ``text-embedding-3-small``.

    Requests are sent in chunks of EMBED_BATCH_SIZE so large corpora do not
    exceed the per-request input limit. Returns a float32 array with one row
    per input text, in input order; empty input returns an empty array
    without calling the API.
    """
    if not texts:
        return np.empty((0, 0), dtype='float32')
    client = OpenAI()
    vectors: List[List[float]] = []
    for start in range(0, len(texts), EMBED_BATCH_SIZE):
        batch = texts[start:start + EMBED_BATCH_SIZE]
        response = client.embeddings.create(model='text-embedding-3-small', input=batch)
        vectors.extend(item.embedding for item in response.data)
    return np.array(vectors, dtype='float32')


def vector_client() -> chromadb.PersistentClient:
    """Open the on-disk Chroma client rooted at VECTOR_DIR."""
    return chromadb.PersistentClient(path=str(VECTOR_DIR))


@st.cache_resource(show_spinner=False)
def get_cross_encoder() -> CrossEncoder:
    """Load (once per process) the MS MARCO MiniLM cross-encoder used for reranking."""
    return CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')


@st.cache_resource(show_spinner=True)
def init_vector_store(kb_dict: Dict[str, Any], stats_payload: str):
    """(Re)build the Chroma collection from the knowledge base and stats.

    ``stats_payload`` is the stats frame serialised with
    ``to_json(orient='records')``; a string is used so Streamlit can hash it
    as a cache key. Any existing collection with COLLECTION_NAME is dropped
    first. Returns ``(corpus, collection)``.
    """
    # Wrap in StringIO: passing literal JSON text to read_json is deprecated (pandas >= 2.1).
    stats_df = pd.read_json(io.StringIO(stats_payload), orient='records')
    corpus = build_corpus(kb_dict, stats_df)
    client = vector_client()
    try:
        client.delete_collection(name=COLLECTION_NAME)
    except Exception:
        # Best-effort: the collection may not exist yet, and the exception type varies by Chroma version.
        pass
    collection = client.get_or_create_collection(name=COLLECTION_NAME, metadata={'hnsw:space': 'cosine'})
    embeddings = embed_texts([doc['text'] for doc in corpus])
    collection.add(
        ids=[doc['id'] for doc in corpus],
        documents=[doc['text'] for doc in corpus],
        metadatas=[{'type': doc['type']} for doc in corpus],
        embeddings=[vec.tolist() for vec in embeddings]
    )
    return corpus, collection


def vector_search(question: str, collection) -> List[Dict[str, Any]]:
    """Return up to FETCH_K nearest documents for ``question``.

    Each hit carries ``score = 1 - distance``; the collection is created
    with cosine space, so this is cosine similarity.
    """
    query_vec = embed_texts([question])[0].tolist()
    result = collection.query(query_embeddings=[query_vec], n_results=FETCH_K,
                              include=['documents', 'metadatas', 'distances'])
    hits = []
    rows = zip(result['distances'][0], result['ids'][0], result['documents'][0], result['metadatas'][0])
    for distance, doc_id, text, meta in rows:
        meta = meta or {}  # Chroma returns None for entries stored without metadata
        identifier = doc_id or meta.get('id') or f"doc-{len(hits)}"
        hits.append({'id': identifier, 'type': meta.get('type', 'unknown'), 'text': text,
                     'score': float(1 - distance)})
    return hits


def _rank_by_scores(candidates: List[Dict[str, Any]], scores) -> List[Dict[str, Any]]:
    """Return the top CONTEXT_K candidate copies tagged with ``rerank_score``, best first."""
    enriched = [{**doc, 'rerank_score': float(score)} for doc, score in zip(candidates, scores)]
    enriched.sort(key=lambda d: d['rerank_score'], reverse=True)
    return enriched[:CONTEXT_K]


def rerank_none(query: str, candidates: List[Dict[str, Any]]):
    """Keep the vector-search order and truncate to CONTEXT_K."""
    return candidates[:CONTEXT_K]


def rerank_bm25(query: str, candidates: List[Dict[str, Any]]):
    """Rerank candidates with BM25 over lower-cased whitespace tokens."""
    if not candidates:
        return []
    tokenized = [doc['text'].lower().split() for doc in candidates]
    bm25 = BM25Okapi(tokenized)
    return _rank_by_scores(candidates, bm25.get_scores(query.lower().split()))


def rerank_cross_encoder(query: str, candidates: List[Dict[str, Any]]):
    """Rerank candidates with the cached cross-encoder's (query, text) relevance scores."""
    if not candidates:
        return []
    model = get_cross_encoder()
    pairs = [[query, doc['text']] for doc in candidates]
    return _rank_by_scores(candidates, model.predict(pairs))


RERANKERS = {
    'cross_encoder': rerank_cross_encoder,
    'bm25': rerank_bm25,
    'none': rerank_none
}


def openai_websearch(query: str, max_results: int = 4) -> Dict[str, Any]:
    """Run an OpenAI Responses web search restricted to ALLOWED_WEB_DOMAINS.

    Returns ``{'query', 'summary'}``. If the response has no ``output_text``,
    the summary falls back to the raw response dump.
    """
    # The web_search tool filters by domain, so drop path suffixes such as '/cricket'
    # (that entry therefore widens to the whole host).
    domains = list(dict.fromkeys(entry.split('/', 1)[0] for entry in ALLOWED_WEB_DOMAINS))
    # The API has no result-count parameter; ask for the limit in the instructions instead.
    limit = max(1, int(max_results))
    client = OpenAI()
    response = client.responses.create(
        model='gpt-4.1-mini',
        instructions=f'Return at most {limit} concise bullet summaries citing sources.',
        input=query,
        tools=[{'type': 'web_search', 'filters': {'allowed_domains': domains}}],
        temperature=0
    )
    summary = getattr(response, 'output_text', None)
    if not summary:
        try:
            summary = response.model_dump_json(indent=2)
        except Exception:
            summary = str(response)
    return {'query': query, 'summary': summary}


def get_match_facts(kb: Dict[str, Any], match_id: str) -> Dict[str, Any]:
    """Return structured facts for ``match_id`` (case-insensitive), or an ``error`` dict."""
    match = next((m for m in kb['matches'] if m['match_id'].lower() == match_id.lower()), None)
    if not match:
        return {'error': f'Unknown match {match_id}'}
    return {
        'match_id': match['match_id'],
        'stage': match['stage'],
        'venue': match['venue'],
        'result': match['result'],
        'theme': match['theme'],
        'top_batters': match['top_performers']['batting'],
        'top_bowlers': match['top_performers']['bowling'],
        'impact_player': match['impact_player'],
        'tactical_notes': match['tactical_notes']
    }


def get_player_profile(kb: Dict[str, Any], player_name: str) -> Dict[str, Any]:
    """Return the knowledge-base profile for ``player_name`` (case-insensitive), or an ``error`` dict."""
    profile = next((p for p in kb['player_profiles'] if p['player'].lower() == player_name.lower()), None)
    if not profile:
        return {'error': f'Unknown player {player_name}'}
    return profile


def get_player_stats(stats_df: pd.DataFrame, player_name: str, recent_rows: int = 5) -> Dict[str, Any]:
    """Return the ``recent_rows`` most recent seasons of CSV stats for ``player_name``.

    Matching is case-insensitive. ``recent_rows`` comes from the model and is
    clamped to at least 1. Returns an ``error`` dict when nothing matches.
    """
    player_rows = stats_df[stats_df['Player_Name'].str.lower() == player_name.lower()].sort_values(
        'Year', ascending=False)
    if player_rows.empty:
        return {'error': f'No stats found for {player_name}'}
    sliced = player_rows.head(max(1, int(recent_rows)))
    cols = ['Year', 'Matches_Batted', 'Runs_Scored', 'Batting_Average', 'Batting_Strike_Rate',
            'Matches_Bowled', 'Wickets_Taken', 'Bowling_Average', 'Economy_Rate']
    return {
        'player': player_name,
        'seasons': sliced[cols].to_dict(orient='records')
    }


def _tool_schemas() -> List[Dict[str, Any]]:
    """Return the function-calling schemas advertised to the chat model."""
    return [
        {
            'type': 'function',
            'function': {
                'name': 'get_match_facts',
                'description': 'Return structured stats for a match id such as IPL2024-42',
                'parameters': {'type': 'object', 'properties': {'match_id': {'type': 'string'}},
                               'required': ['match_id']}
            }
        },
        {
            'type': 'function',
            'function': {
                'name': 'get_player_profile',
                'description': 'Return roles and highlights for a player',
                'parameters': {'type': 'object', 'properties': {'player_name': {'type': 'string'}},
                               'required': ['player_name']}
            }
        },
        {
            'type': 'function',
            'function': {
                'name': 'get_player_stats',
                'description': 'Return recent CSV stats for a player',
                'parameters': {'type': 'object',
                               'properties': {'player_name': {'type': 'string'},
                                              'recent_rows': {'type': 'integer', 'default': 5}},
                               'required': ['player_name']}
            }
        },
        {
            'type': 'function',
            'function': {
                'name': 'openai_websearch',
                'description': 'Use OpenAI web search limited to Cricbuzz/ESPNCricinfo/IPLT20/BCCI/HT.',
                'parameters': {'type': 'object',
                               'properties': {'query': {'type': 'string'},
                                              'max_results': {'type': 'integer', 'default': 4}},
                               'required': ['query']}
            }
        }
    ]


def _execute_tool_call(call, available_tools: Dict[str, Any]) -> Dict[str, Any]:
    """Run one model-requested tool call, returning an ``error`` payload instead of raising.

    Unknown tool names, malformed JSON arguments and tool failures are
    reported back to the model, so one bad call does not abort the page.
    """
    name = call.function.name
    func = available_tools.get(name)
    if func is None:
        return {'error': f'Unknown tool {name}'}
    try:
        params = json.loads(call.function.arguments or '{}')
    except json.JSONDecodeError as exc:
        return {'error': f'Invalid arguments for {name}: {exc}'}
    try:
        return func(**params)
    except Exception as exc:  # includes TypeError from unexpected or non-dict arguments
        return {'error': f'{name} failed: {exc}'}


def run_agent(question: str, kb: Dict[str, Any], stats_df: pd.DataFrame, collection, rerank_key: str):
    """Answer ``question`` with retrieval + reranking and at most one round of tool calls.

    Returns ``(answer_text, contexts)``, where ``contexts`` are the reranked
    documents shown to the model. An unknown ``rerank_key`` falls back to the
    cross-encoder.
    """
    base_hits = vector_search(question, collection)
    reranker = RERANKERS.get(rerank_key, rerank_cross_encoder)
    contexts = reranker(question, base_hits)
    context_block = '\n\n'.join(
        f"[{c.get('type','doc')}::{c.get('id','unknown')}] {c['text']}" for c in contexts
    )

    dream11_hint = ''
    lowered = question.lower()
    if any(token in lowered for token in ['dream11', 'dream 11', 'playing xi', 'playing 11', 'best 11']):
        dream11_hint = '\n\nFocus on returning a balanced fantasy XI (2 WKs max, at least 3 bowlers) with cited rationale.'

    system_prompt = 'You are an IPL analyst who cites vector DB evidence and calls tools when needed.'
    user_prompt = f"Question: {question}\n\nContext:\n{context_block}\n\nRerank strategy: {rerank_key}{dream11_hint}"

    client = OpenAI()
    first = client.chat.completions.create(
        model='gpt-4.1-mini',
        messages=[
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': user_prompt}
        ],
        tools=_tool_schemas(),
        tool_choice='auto',
        temperature=0.2
    )
    message = first.choices[0].message
    if not message.tool_calls:
        return message.content, contexts

    available_tools = {
        'get_match_facts': lambda **kwargs: get_match_facts(kb, **kwargs),
        'get_player_profile': lambda **kwargs: get_player_profile(kb, **kwargs),
        'get_player_stats': lambda **kwargs: get_player_stats(stats_df, **kwargs),
        'openai_websearch': openai_websearch
    }
    tool_messages = [
        {
            'role': 'tool',
            'tool_call_id': call.id,
            'name': call.function.name,
            'content': json.dumps(_execute_tool_call(call, available_tools))
        }
        for call in message.tool_calls
    ]

    follow_up = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': user_prompt},
        message,
        *tool_messages
    ]
    second = client.chat.completions.create(
        model='gpt-4.1-mini',
        messages=follow_up,
        temperature=0.2
    )
    return second.choices[0].message.content, contexts


# Main execution
try:
    kb = load_kb()
    stats_df = load_stats_df()
    # Serialised form doubles as a hashable cache key for init_vector_store.
    stats_payload = stats_df.to_json(orient='records')

    if st.sidebar.button('Build / refresh vector store', disabled=not api_key):
        # Clearing the cache forces init_vector_store below to rebuild on this run.
        init_vector_store.clear()
        st.sidebar.success('Rebuilt vector store')

    if not api_key:
        st.warning('Provide an OpenAI API key to run the agent.')
        st.stop()

    _corpus, collection = init_vector_store(kb, stats_payload)

    query = st.text_area('Ask anything about IPL 2024 (matches, players, venues, tactics)', height=140)
    if st.button('Run query', disabled=not query.strip()):
        with st.spinner('Calling vector DB + RAG agent...'):
            answer, contexts = run_agent(query.strip(), kb, stats_df, collection, rerank_strategy)
        st.success('Answer')
        st.write(answer)
        with st.expander('Retrieved context'):
            for ctx in contexts:
                sim = ctx.get('score', 0.0)
                rerank_score = ctx.get('rerank_score')
                suffix = f", rerank={rerank_score:.2f}" if rerank_score is not None else ''
                st.markdown(f"**{ctx.get('type','doc')}::{ctx.get('id','unknown')}** (sim={sim:.2f}{suffix})")
                st.write(ctx['text'])
                st.divider()
    else:
        st.info('Enter a query and click run to test the pipeline.')
except FileNotFoundError as e:
    st.error(f'Data file not found: {e}')
    st.info(f'Looking for files in: {SCRIPT_DIR}')
    st.info('Please ensure ipl_knowledge_base.json and cricket_data.csv are in the same directory as app.py')
except Exception as e:
    st.error(f'Error loading application: {e}')
    st.code(traceback.format_exc())