# NOTE(review): removed Hugging Face Spaces page banner ("Spaces / Running")
# that was scraped into this file — it is page chrome, not Python source.
| """Streamlit UI for the IPL RAG + tool calling demo.""" | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, Any, List | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import chromadb | |
| from rank_bm25 import BM25Okapi | |
| from sentence_transformers import CrossEncoder | |
| from openai import OpenAI | |
# For Hugging Face Spaces: assume data files are in the same directory as app.py
SCRIPT_DIR = Path(__file__).parent.resolve()
DATA_PATH = SCRIPT_DIR / 'ipl_knowledge_base.json'  # knowledge-base JSON read by load_kb()
CSV_PATH = SCRIPT_DIR / 'cricket_data.csv'  # per-season player stats read by load_stats_df()
VECTOR_DIR = SCRIPT_DIR / 'vector_store'  # on-disk persistence directory for Chroma
VECTOR_DIR.mkdir(parents=True, exist_ok=True)
FETCH_K = 8  # candidates fetched from the vector store per query
CONTEXT_K = 4  # candidates kept after reranking
COLLECTION_NAME = 'ipl_rag_ui'
# Domains the web-search tool is allowed to cite.
# NOTE(review): the last entry carries a URL path, not a bare hostname —
# confirm the web-search domain filter accepts paths.
ALLOWED_WEB_DOMAINS = [
    'www.cricbuzz.com',
    'www.espncricinfo.com',
    'www.iplt20.com',
    'www.bcci.tv',
    'www.hindustantimes.com/cricket'
]
# Streamlit page chrome + sidebar controls (API key, rerank strategy picker).
# NOTE(review): the 'π' icon/title glyph looks like a mojibake'd emoji
# (e.g. a cricket emoji) — confirm the intended character.
st.set_page_config(page_title='IPL RAG Copilot', page_icon='π', layout='wide')
st.title('π IPL RAG Copilot')
st.caption('Chroma vector DB + multi-rerank pipeline for IPL insights')
# Key is taken from the sidebar (defaulting to the environment) and re-exported
# so that bare OpenAI() constructors elsewhere in the file pick it up.
api_key = st.sidebar.text_input('OPENAI_API_KEY', value=os.environ.get('OPENAI_API_KEY', ''), type='password')
st.sidebar.markdown('Provide a key, then rebuild the vector store if needed.')
rerank_strategy = st.sidebar.selectbox('Rerank strategy', ['cross_encoder', 'bm25', 'none'], index=0)
if api_key:
    os.environ['OPENAI_API_KEY'] = api_key
@st.cache_data(show_spinner=False)
def load_kb() -> Dict[str, Any]:
    """Load the knowledge-base JSON, cached across Streamlit reruns.

    Streamlit re-executes the whole script on every interaction; without
    st.cache_data the JSON file would be re-read and re-parsed each time.
    (The main script's use of `.clear()` on the cached build function shows
    caching decorators were intended throughout.)

    Returns:
        The parsed knowledge base as a dict.
    """
    return json.loads(Path(DATA_PATH).read_text(encoding='utf-8'))
@st.cache_data(show_spinner=False)
def load_stats_df() -> pd.DataFrame:
    """Load and preprocess the CSV stats, cached across Streamlit reruns.

    Drops placeholder rows (Year == 'No stats'), then coerces every stat
    column to numeric; cells that fail coercion become NaN rather than
    raising. Cached with st.cache_data to avoid re-reading the CSV on each
    script rerun.

    Returns:
        A DataFrame with numeric Year and stat columns.
    """
    df = pd.read_csv(Path(CSV_PATH))
    # Placeholder rows carry the literal string 'No stats' in Year.
    df = df[df['Year'] != 'No stats'].copy()
    df['Year'] = pd.to_numeric(df['Year'])
    numeric_cols = [
        'Matches_Batted', 'Not_Outs', 'Runs_Scored', 'Batting_Average', 'Balls_Faced', 'Batting_Strike_Rate',
        'Centuries', 'Half_Centuries', 'Fours', 'Sixes', 'Matches_Bowled', 'Balls_Bowled', 'Runs_Conceded',
        'Wickets_Taken', 'Bowling_Average', 'Economy_Rate', 'Bowling_Strike_Rate', 'Four_Wicket_Hauls', 'Five_Wicket_Hauls'
    ]
    for col in numeric_cols:
        df[col] = pd.to_numeric(df[col], errors='coerce')
    return df
def summarize_player_stats(player_df: pd.DataFrame) -> str:
    """Build a one-line natural-language capsule of a player's recent form.

    Looks at up to the 4 most recent seasons (by Year) and summarizes batting
    and bowling separately.

    Bug fix: the original used `int(row.Matches_Batted or 0)` — NaN is truthy,
    so `nan or 0` is still NaN and `int(nan)` raises ValueError whenever a
    count column is missing. Counts are now coerced NaN-safely.

    Args:
        player_df: Rows for a single player from the stats CSV.

    Returns:
        "Recent batting -> ... Recent bowling -> ..." summary string.
    """
    def as_int(value) -> int:
        # NaN-safe integer coercion: int(nan) raises; missing counts become 0.
        return int(value) if pd.notna(value) else 0

    recent = player_df.sort_values('Year', ascending=False).head(4)
    batting_parts = []
    bowling_parts = []
    for _, row in recent.iterrows():
        if pd.notna(row.get('Runs_Scored', np.nan)):
            batting_parts.append(
                f"{as_int(row.Year)}: {as_int(row.Matches_Batted)} inns, "
                f"{as_int(row.Runs_Scored)} runs @ avg {row.Batting_Average:.1f} SR {row.Batting_Strike_Rate:.1f}"
            )
        if pd.notna(row.get('Wickets_Taken', np.nan)):
            bowling_parts.append(
                f"{as_int(row.Year)}: {as_int(row.Matches_Bowled)} matches, "
                f"{as_int(row.Wickets_Taken)} wkts @ avg {row.Bowling_Average:.1f} eco {row.Economy_Rate:.2f}"
            )
    batting_summary = '; '.join(batting_parts) if batting_parts else 'No batting sample'
    bowling_summary = '; '.join(bowling_parts) if bowling_parts else 'No bowling sample'
    return f"Recent batting -> {batting_summary}. Recent bowling -> {bowling_summary}."
def build_corpus(kb: Dict[str, Any], stats_df: pd.DataFrame) -> List[Dict[str, Any]]:
    """Flatten the KB JSON plus per-player stat capsules into retrievable docs.

    Each document is a dict with 'id', 'type', and a natural-language 'text'
    field destined for embedding and vector search.
    """
    def doc(doc_id: str, doc_type: str, text: str) -> Dict[str, Any]:
        # Uniform document envelope for every corpus entry.
        return {'id': doc_id, 'type': doc_type, 'text': text}

    documents: List[Dict[str, Any]] = []
    # One doc per match, covering result, theme, events, and tactics.
    for match in kb['matches']:
        documents.append(doc(
            match['match_id'], 'match',
            f"{match['stage']} match {match['match_id']} at {match['venue']} featuring {match['teams']}. "
            f"Result: {match['result']}. Theme: {match['theme']}. Key events: {'; '.join(match['key_events'])}. "
            f"Impact player: {match['impact_player']}. Tactical notes: {match['tactical_notes']}"
        ))
    # One doc per player profile (roles, strengths, highlights).
    for player in kb['player_profiles']:
        documents.append(doc(
            player['player'], 'player',
            f"{player['player']} ({player['team']}) roles {', '.join(player['roles'])}. "
            f"Strengths: {', '.join(player['strengths'])}. Highlights: {player['season_highlights']}. "
            f"Impact metrics: {player.get('impact_metrics')}"
        ))
    # One doc per team brief.
    for team in kb['team_briefs']:
        documents.append(doc(
            f"team-{team['team']}", 'team',
            f"{team['team']} led by {team['captain']} (coach {team['coach']}) style: {team['style']}"
        ))
    # One doc per venue note.
    for venue in kb['stadium_notes']:
        documents.append(doc(
            f"venue-{venue['venue']}", 'venue',
            f"{venue['venue']} avg first innings {venue['average_first_innings']} runs; note: {venue['spice']}"
        ))
    # Single season-summary doc.
    season = kb['season_summary']
    documents.append(doc(
        f"season-{season['season']}", 'season',
        f"Season narrative: {season['narrative']}. "
        f"Trends: {'; '.join(season['trends'])}. Awards: {season['awards']}"
    ))
    # One stat-capsule doc per player found in the CSV.
    for player_name, player_df in stats_df.groupby('Player_Name'):
        documents.append(doc(
            f"stats-{player_name}", 'stat',
            f"Stat capsule for {player_name}: {summarize_player_stats(player_df)}"
        ))
    return documents
def embed_texts(texts: List[str]) -> np.ndarray:
    """Embed texts with OpenAI's text-embedding-3-small model.

    Returns a float32 array of shape (len(texts), embedding_dim).
    """
    response = OpenAI().embeddings.create(model='text-embedding-3-small', input=texts)
    vectors = [item.embedding for item in response.data]
    return np.asarray(vectors, dtype='float32')
def vector_client() -> chromadb.PersistentClient:
    """Return a Chroma client persisting to VECTOR_DIR on local disk."""
    return chromadb.PersistentClient(path=str(VECTOR_DIR))
@st.cache_resource(show_spinner=False)
def get_cross_encoder() -> CrossEncoder:
    """Load the MS MARCO cross-encoder used for reranking, once per process.

    Streamlit reruns the script on every interaction; without st.cache_resource
    the model weights would be re-instantiated on every reranked query.
    """
    return CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
@st.cache_resource
def init_vector_store(kb_dict: Dict[str, Any], stats_payload: str):
    """Build (and cache) the Chroma collection over the embedded corpus.

    Decorated with st.cache_resource: the main script calls
    `init_vector_store.clear()` on the rebuild button, and `.clear()` only
    exists on cache-decorated functions — undecorated, that call raises
    AttributeError. Caching also avoids re-embedding the whole corpus
    (paid OpenAI calls) on every Streamlit rerun. `stats_payload` is the
    stats DataFrame serialized to JSON so the cache key stays stable.

    Args:
        kb_dict: Parsed knowledge base.
        stats_payload: Stats DataFrame as a JSON records string.

    Returns:
        Tuple of (corpus doc list, Chroma collection).
    """
    from io import StringIO  # read_json on a literal string is deprecated
    stats_df = pd.read_json(StringIO(stats_payload))
    corpus = build_corpus(kb_dict, stats_df)
    client = vector_client()
    try:
        client.delete_collection(name=COLLECTION_NAME)
    except Exception:
        # Collection may not exist yet on first build; best-effort cleanup.
        pass
    collection = client.get_or_create_collection(name=COLLECTION_NAME, metadata={'hnsw:space': 'cosine'})
    embeddings = embed_texts([doc['text'] for doc in corpus])
    collection.add(
        ids=[doc['id'] for doc in corpus],
        documents=[doc['text'] for doc in corpus],
        metadatas=[{'type': doc['type']} for doc in corpus],
        embeddings=[vec.tolist() for vec in embeddings]
    )
    return corpus, collection
def vector_search(question: str, collection) -> List[Dict[str, Any]]:
    """Embed the question and return the FETCH_K nearest corpus docs.

    Each hit carries id, type, text, and a similarity score derived from the
    collection's cosine distance (1 - distance; higher is closer).
    """
    query_vec = embed_texts([question])[0].tolist()
    result = collection.query(
        query_embeddings=[query_vec],
        n_results=FETCH_K,
        include=['documents', 'metadatas', 'distances'],
    )
    distances = result['distances'][0]
    ids = result['ids'][0]
    texts = result['documents'][0]
    metas = result['metadatas'][0]
    hits: List[Dict[str, Any]] = []
    for position, (distance, doc_id, text, meta) in enumerate(zip(distances, ids, texts, metas)):
        hits.append({
            'id': doc_id or meta.get('id') or f"doc-{position}",
            'type': meta.get('type', 'unknown'),
            'text': text,
            'score': float(1 - distance),
        })
    return hits
def rerank_none(query: str, candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """No-op reranker: keep the vector store's ordering, truncated to CONTEXT_K."""
    return candidates[:CONTEXT_K]
def rerank_bm25(query: str, candidates: List[Dict[str, Any]]):
    """Re-order candidates by BM25 lexical overlap with the query.

    Whitespace-tokenizes lowercased text, scores each candidate, attaches the
    score under 'rerank_score', and returns the top CONTEXT_K.
    """
    if not candidates:
        return []
    corpus_tokens = [doc['text'].lower().split() for doc in candidates]
    scores = BM25Okapi(corpus_tokens).get_scores(query.lower().split())
    # Copy each doc and annotate with its BM25 score.
    scored = [dict(doc, rerank_score=float(score)) for doc, score in zip(candidates, scores)]
    ranked = sorted(scored, key=lambda d: d.get('rerank_score', 0), reverse=True)
    return ranked[:CONTEXT_K]
def rerank_cross_encoder(query: str, candidates: List[Dict[str, Any]]):
    """Re-score candidates with the cross-encoder relevance model.

    Runs (query, text) pairs through the model, attaches 'rerank_score', and
    returns the top CONTEXT_K by descending score.
    """
    if not candidates:
        return []
    encoder = get_cross_encoder()
    pair_scores = encoder.predict([[query, doc['text']] for doc in candidates])
    scored = [dict(doc, rerank_score=float(s)) for doc, s in zip(candidates, pair_scores)]
    return sorted(scored, key=lambda d: d.get('rerank_score', 0), reverse=True)[:CONTEXT_K]
# Strategy dispatch table keyed by the sidebar's 'Rerank strategy' values.
RERANKERS = {
    'cross_encoder': rerank_cross_encoder,
    'bm25': rerank_bm25,
    'none': rerank_none
}
def openai_websearch(query: str, max_results: int = 4) -> Dict[str, Any]:
    """Run an OpenAI-hosted web search restricted to ALLOWED_WEB_DOMAINS.

    Fixes vs. the original call:
    - `responses.create` has no `web_search` keyword argument (it raised
      TypeError); the domain restriction belongs on the tool entry itself via
      `filters.allowed_domains`.
    - Input message content parts must use type 'input_text' for every role
      (the system part used 'text').
    - The API has no per-call `max_results`, so the cap is expressed in the
      prompt instead.

    Args:
        query: Search query text.
        max_results: Soft cap on cited sources, passed via the prompt.

    Returns:
        Dict with the original query and a text summary of the results.
    """
    client = OpenAI()
    response = client.responses.create(
        model='gpt-4.1-mini',
        input=[
            {'role': 'system', 'content': [{'type': 'input_text', 'text': 'Return concise bullet summaries citing sources.'}]},
            {'role': 'user', 'content': [{'type': 'input_text', 'text': f"{query} (cite at most {max_results} sources)"}]}
        ],
        tools=[{'type': 'web_search', 'filters': {'allowed_domains': ALLOWED_WEB_DOMAINS}}],
        temperature=0
    )
    summary = getattr(response, 'output_text', None)
    if not summary:
        # Fall back to a raw dump so the tool never returns an empty answer.
        try:
            summary = response.model_dump_json(indent=2)
        except Exception:
            summary = str(response)
    return {'query': query, 'summary': summary}
def get_match_facts(kb: Dict[str, Any], match_id: str) -> Dict[str, Any]:
    """Look up a match by id (case-insensitive) and return a flat fact dict.

    Returns {'error': ...} when the id is not in the knowledge base.
    """
    wanted = match_id.lower()
    match = None
    for candidate in kb['matches']:
        if candidate['match_id'].lower() == wanted:
            match = candidate
            break
    if match is None:
        return {'error': f'Unknown match {match_id}'}
    performers = match['top_performers']
    return {
        'match_id': match['match_id'],
        'stage': match['stage'],
        'venue': match['venue'],
        'result': match['result'],
        'theme': match['theme'],
        'top_batters': performers['batting'],
        'top_bowlers': performers['bowling'],
        'impact_player': match['impact_player'],
        'tactical_notes': match['tactical_notes'],
    }
def get_player_profile(kb: Dict[str, Any], player_name: str) -> Dict[str, Any]:
    """Return the raw KB profile entry for a player, matched case-insensitively.

    Returns {'error': ...} when the player is not in the knowledge base.
    """
    needle = player_name.lower()
    for profile in kb['player_profiles']:
        if profile['player'].lower() == needle:
            return profile
    return {'error': f'Unknown player {player_name}'}
def get_player_stats(stats_df: pd.DataFrame, player_name: str, recent_rows: int = 5) -> Dict[str, Any]:
    """Return up to `recent_rows` most recent seasons of CSV stats for a player.

    Matching is case-insensitive on Player_Name; returns {'error': ...} when
    no rows match.
    """
    mask = stats_df['Player_Name'].str.lower() == player_name.lower()
    matched = stats_df[mask].sort_values('Year', ascending=False)
    if matched.empty:
        return {'error': f'No stats found for {player_name}'}
    columns = ['Year', 'Matches_Batted', 'Runs_Scored', 'Batting_Average', 'Batting_Strike_Rate',
               'Matches_Bowled', 'Wickets_Taken', 'Bowling_Average', 'Economy_Rate']
    recent = matched.head(recent_rows)
    return {
        'player': player_name,
        'seasons': recent[columns].to_dict(orient='records')
    }
def run_agent(question: str, kb: Dict[str, Any], stats_df: pd.DataFrame, collection, rerank_key: str):
    """Answer a question via retrieve -> rerank -> tool-calling chat completion.

    Returns a tuple (answer_text, contexts): the model's final answer plus the
    reranked context docs used to build the prompt (shown in the UI expander).
    Makes one chat round; if the model requests tools, executes them locally
    and makes a second round with the tool outputs.
    """
    # Retrieval: FETCH_K vector hits, narrowed to CONTEXT_K by the reranker.
    base_hits = vector_search(question, collection)
    # Unknown strategy keys fall back to the cross-encoder reranker.
    reranker = RERANKERS.get(rerank_key, rerank_cross_encoder)
    contexts = reranker(question, base_hits)
    # Inline the evidence as "[type::id] text" paragraphs for the prompt.
    context_block = '\n\n'.join([f"[{c.get('type','doc')}::{c.get('id','unknown')}] {c['text']}" for c in contexts])
    dream11_hint = ''
    lowered = question.lower()
    # Fantasy-XI style questions get an extra steering instruction appended.
    if any(token in lowered for token in ['dream11', 'dream 11', 'playing xi', 'playing 11', 'best 11']):
        dream11_hint = '\n\nFocus on returning a balanced fantasy XI (2 WKs max, at least 3 bowlers) with cited rationale.'
    system_prompt = 'You are an IPL analyst who cites vector DB evidence and calls tools when needed.'
    user_prompt = f"Question: {question}\n\nContext:\n{context_block}\n\nRerank strategy: {rerank_key}{dream11_hint}"
    # Function-tool JSON schemas; names map to available_tools further down.
    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_match_facts',
                'description': 'Return structured stats for a match id such as IPL2024-42',
                'parameters': {'type': 'object', 'properties': {'match_id': {'type': 'string'}}, 'required': ['match_id']}
            }
        },
        {
            'type': 'function',
            'function': {
                'name': 'get_player_profile',
                'description': 'Return roles and highlights for a player',
                'parameters': {'type': 'object', 'properties': {'player_name': {'type': 'string'}}, 'required': ['player_name']}
            }
        },
        {
            'type': 'function',
            'function': {
                'name': 'get_player_stats',
                'description': 'Return recent CSV stats for a player',
                'parameters': {'type': 'object', 'properties': {'player_name': {'type': 'string'}, 'recent_rows': {'type': 'integer', 'default': 5}}, 'required': ['player_name']}
            }
        },
        {
            'type': 'function',
            'function': {
                'name': 'openai_websearch',
                'description': 'Use OpenAI web search limited to Cricbuzz/ESPNCricinfo/IPLT20/BCCI/HT.',
                'parameters': {'type': 'object', 'properties': {'query': {'type': 'string'}, 'max_results': {'type': 'integer', 'default': 4}}, 'required': ['query']}
            }
        }
    ]
    client = OpenAI()
    # Round 1: the model may answer directly or request tool calls.
    first = client.chat.completions.create(
        model='gpt-4.1-mini',
        messages=[
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': user_prompt}
        ],
        tools=tools,
        tool_choice='auto',
        temperature=0.2
    )
    message = first.choices[0].message
    if not message.tool_calls:
        # Direct answer; no tool round needed.
        return message.content, contexts
    tool_messages = []
    # Dispatch table binding tool names to local implementations, with
    # kb / stats_df pre-bound via closures.
    available_tools = {
        'get_match_facts': lambda **kwargs: get_match_facts(kb, **kwargs),
        'get_player_profile': lambda **kwargs: get_player_profile(kb, **kwargs),
        'get_player_stats': lambda **kwargs: get_player_stats(stats_df, **kwargs),
        'openai_websearch': openai_websearch
    }
    # Execute each requested call and package the result as a 'tool' message.
    for call in message.tool_calls:
        func = available_tools[call.function.name]
        params = json.loads(call.function.arguments)
        payload = func(**params)
        tool_messages.append({
            'role': 'tool',
            'tool_call_id': call.id,
            'name': call.function.name,
            'content': json.dumps(payload)
        })
    # Round 2: replay the conversation including the assistant's tool-call
    # message, then the tool outputs, so the model composes the final answer.
    follow_up = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': user_prompt},
        message,
        *tool_messages
    ]
    second = client.chat.completions.create(
        model='gpt-4.1-mini',
        messages=follow_up,
        temperature=0.2
    )
    return second.choices[0].message.content, contexts
# Main execution: load data, (re)build the vector store, and drive the query UI.
try:
    kb = load_kb()
    stats_df = load_stats_df()
    # Serialize stats to JSON so init_vector_store receives a stable, hashable
    # payload (useful when the build function is cache-decorated).
    stats_payload = stats_df.to_json(orient='records')
    if st.sidebar.button('Build / refresh vector store', disabled=not api_key):
        # Bug fix: `.clear()` only exists on Streamlit cache-decorated
        # functions. Guard so an undecorated init_vector_store doesn't crash
        # the whole app with AttributeError when the button is pressed.
        cache_clear = getattr(init_vector_store, 'clear', None)
        if cache_clear is not None:
            cache_clear()
        st.sidebar.success('Rebuilt vector store')
    if not api_key:
        st.warning('Provide an OpenAI API key to run the agent.')
        st.stop()
    corpus, collection = init_vector_store(kb, stats_payload)
    query = st.text_area('Ask anything about IPL 2024 (matches, players, venues, tactics)', height=140)
    if st.button('Run query', disabled=not query.strip()):
        with st.spinner('Calling vector DB + RAG agent...'):
            answer, contexts = run_agent(query.strip(), kb, stats_df, collection, rerank_strategy)
        st.success('Answer')
        st.write(answer)
        with st.expander('Retrieved context'):
            for ctx in contexts:
                sim = ctx.get('score', 0.0)
                rerank_score = ctx.get('rerank_score')
                suffix = f", rerank={rerank_score:.2f}" if rerank_score is not None else ''
                st.markdown(f"**{ctx.get('type','doc')}::{ctx.get('id','unknown')}** (sim={sim:.2f}{suffix})")
                st.write(ctx['text'])
                st.divider()
    else:
        st.info('Enter a query and click run to test the pipeline.')
except FileNotFoundError as e:
    st.error(f'Data file not found: {e}')
    st.info(f'Looking for files in: {SCRIPT_DIR}')
    st.info('Please ensure ipl_knowledge_base.json and cricket_data.csv are in the same directory as app.py')
except Exception as e:
    # Top-level boundary: surface the traceback in the UI instead of crashing.
    st.error(f'Error loading application: {e}')
    import traceback
    st.code(traceback.format_exc())