Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +53 -36
src/streamlit_app.py
CHANGED
|
@@ -12,11 +12,13 @@ from rank_bm25 import BM25Okapi
|
|
| 12 |
from sentence_transformers import CrossEncoder
|
| 13 |
from openai import OpenAI
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
| 19 |
VECTOR_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 20 |
FETCH_K = 8
|
| 21 |
CONTEXT_K = 4
|
| 22 |
COLLECTION_NAME = 'ipl_rag_ui'
|
|
@@ -42,13 +44,18 @@ if api_key:
|
|
| 42 |
|
| 43 |
@st.cache_data(show_spinner=False)
|
| 44 |
def load_kb() -> Dict[str, Any]:
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
| 46 |
return json.load(f)
|
| 47 |
|
| 48 |
|
| 49 |
@st.cache_data(show_spinner=False)
|
| 50 |
def load_stats_df() -> pd.DataFrame:
|
| 51 |
-
|
|
|
|
|
|
|
| 52 |
df = df[df['Year'] != 'No stats'].copy()
|
| 53 |
df['Year'] = pd.to_numeric(df['Year'])
|
| 54 |
numeric_cols = [
|
|
@@ -373,33 +380,43 @@ def run_agent(question: str, kb: Dict[str, Any], stats_df: pd.DataFrame, collect
|
|
| 373 |
return second.choices[0].message.content, contexts
|
| 374 |
|
| 375 |
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
st.sidebar.
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from sentence_transformers import CrossEncoder
|
| 13 |
from openai import OpenAI
|
| 14 |
|
| 15 |
+
# For Hugging Face Spaces: assume data files are in the same directory as app.py
|
| 16 |
+
SCRIPT_DIR = Path(__file__).parent.resolve()
|
| 17 |
+
DATA_PATH = SCRIPT_DIR / 'ipl_knowledge_base.json'
|
| 18 |
+
CSV_PATH = SCRIPT_DIR / 'cricket_data.csv'
|
| 19 |
+
VECTOR_DIR = SCRIPT_DIR / 'vector_store'
|
| 20 |
VECTOR_DIR.mkdir(parents=True, exist_ok=True)
|
| 21 |
+
|
| 22 |
FETCH_K = 8
|
| 23 |
CONTEXT_K = 4
|
| 24 |
COLLECTION_NAME = 'ipl_rag_ui'
|
|
|
|
| 44 |
|
| 45 |
@st.cache_data(show_spinner=False)
|
| 46 |
def load_kb() -> Dict[str, Any]:
|
| 47 |
+
"""Load knowledge base JSON file."""
|
| 48 |
+
# Ensure DATA_PATH is a Path object
|
| 49 |
+
data_path = Path(DATA_PATH)
|
| 50 |
+
with open(data_path, 'r', encoding='utf-8') as f:
|
| 51 |
return json.load(f)
|
| 52 |
|
| 53 |
|
| 54 |
@st.cache_data(show_spinner=False)
|
| 55 |
def load_stats_df() -> pd.DataFrame:
|
| 56 |
+
"""Load and preprocess CSV stats."""
|
| 57 |
+
csv_path = Path(CSV_PATH)
|
| 58 |
+
df = pd.read_csv(csv_path)
|
| 59 |
df = df[df['Year'] != 'No stats'].copy()
|
| 60 |
df['Year'] = pd.to_numeric(df['Year'])
|
| 61 |
numeric_cols = [
|
|
|
|
| 380 |
return second.choices[0].message.content, contexts
|
| 381 |
|
| 382 |
|
| 383 |
+
# Main execution
|
| 384 |
+
try:
|
| 385 |
+
kb = load_kb()
|
| 386 |
+
stats_df = load_stats_df()
|
| 387 |
+
stats_payload = stats_df.to_json(orient='records')
|
| 388 |
+
|
| 389 |
+
if st.sidebar.button('Build / refresh vector store', disabled=not api_key):
|
| 390 |
+
init_vector_store.clear()
|
| 391 |
+
st.sidebar.success('Rebuilt vector store')
|
| 392 |
+
|
| 393 |
+
if not api_key:
|
| 394 |
+
st.warning('Provide an OpenAI API key to run the agent.')
|
| 395 |
+
st.stop()
|
| 396 |
+
|
| 397 |
+
corpus, collection = init_vector_store(kb, stats_payload)
|
| 398 |
+
|
| 399 |
+
query = st.text_area('Ask anything about IPL 2024 (matches, players, venues, tactics)', height=140)
|
| 400 |
+
if st.button('Run query', disabled=not query.strip()):
|
| 401 |
+
with st.spinner('Calling vector DB + RAG agent...'):
|
| 402 |
+
answer, contexts = run_agent(query.strip(), kb, stats_df, collection, rerank_strategy)
|
| 403 |
+
st.success('Answer')
|
| 404 |
+
st.write(answer)
|
| 405 |
+
with st.expander('Retrieved context'):
|
| 406 |
+
for ctx in contexts:
|
| 407 |
+
sim = ctx.get('score', 0.0)
|
| 408 |
+
rerank_score = ctx.get('rerank_score')
|
| 409 |
+
suffix = f", rerank={rerank_score:.2f}" if rerank_score is not None else ''
|
| 410 |
+
st.markdown(f"**{ctx.get('type','doc')}::{ctx.get('id','unknown')}** (sim={sim:.2f}{suffix})")
|
| 411 |
+
st.write(ctx['text'])
|
| 412 |
+
st.divider()
|
| 413 |
+
else:
|
| 414 |
+
st.info('Enter a query and click run to test the pipeline.')
|
| 415 |
+
except FileNotFoundError as e:
|
| 416 |
+
st.error(f'Data file not found: {e}')
|
| 417 |
+
st.info(f'Looking for files in: {SCRIPT_DIR}')
|
| 418 |
+
st.info('Please ensure ipl_knowledge_base.json and cricket_data.csv are in the same directory as app.py')
|
| 419 |
+
except Exception as e:
|
| 420 |
+
st.error(f'Error loading application: {e}')
|
| 421 |
+
import traceback
|
| 422 |
+
st.code(traceback.format_exc())
|