AgentBase-Platform / pages /retrieval.py
Arastun's picture
fix: dataframe rendering error
65b059f
Raw
History Blame Contribute Delete
8 kB
"""
AgentBase Retrieval UI (limited interface).
Author: Arastun Mammadli
Date: [Current Date]
"""
import html
import json
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import streamlit as st

from retrieval.models.bm25 import BM25Retriever
from retrieval.models.sentence_bert import DenseRetriever
from retrieval.utils import load_queries
# Page-level settings: wide layout so the results table can span the page,
# and the "AgentBase" page title.
st.set_page_config(layout="wide", page_title="AgentBase")
@st.cache_resource()
def load_retrievers(
    agentbase_path: str,
    index_configs: List[str],
    dense_model: str = "BAAI/bge-large-en-v1.5",
    toolret_model: str = "mangopy/ToolRet-trained-bge-large-en-v1.5",
) -> Tuple[dict, dict, dict]:
    """
    Build a BM25, a BGE-large and a ToolRet retriever for every index config.

    Cached with ``st.cache_resource`` so the retrievers are constructed once
    per argument combination and reused across reruns.

    Args:
        agentbase_path: Path to the AgentBase CSV, passed to every retriever.
        index_configs: Index configuration names (e.g. ``"v1"``, ``"naive"``).
        dense_model: Model name for the general dense retriever.
        toolret_model: Model name for the tool-search dense retriever.

    Returns:
        ``(bm25s, bges, toolrets)`` -- three dicts, each keyed by index config.
    """
    bm25s = {}
    bges = {}
    toolrets = {}
    # Build per config in the same interleaved order as before.
    for idx_config in index_configs:
        bm25s[idx_config] = BM25Retriever(agentbase_path, index_config=idx_config)
        bges[idx_config] = DenseRetriever(dense_model, agentbase_path, index_config=idx_config)
        toolrets[idx_config] = DenseRetriever(toolret_model, agentbase_path, index_config=idx_config)
    return bm25s, bges, toolrets
@st.cache_data()
def load_agentbase_data(agentbase_path: str) -> pd.DataFrame:
    """
    Read the AgentBase agents CSV into a DataFrame.

    Uses ``st.cache_data`` rather than ``st.cache_resource``: cache_data hands
    every caller its own copy of the cached DataFrame. In-place edits made by
    one session (e.g. adding a ``scores`` column) therefore cannot leak into
    the object other sessions and reruns receive.

    Args:
        agentbase_path: Path to the agents CSV file.

    Returns:
        The parsed CSV as a DataFrame.
    """
    return pd.read_csv(agentbase_path)
class AgentBaseUI:
    """
    AgentBase Streamlit-based UI Components.

    Loads the agent catalogue, the platform table and the retrievers, then
    renders a search header (``header_panel``) and a results table
    (``retrieval_panel``).
    """

    # Help text shown in the "See explanation" expander. Sub-items are
    # indented so Markdown renders them as nested lists.
    _EXPLANATION_MD = (
        "- **Retrieval Models**:\n"
        "    - **BGE-Large**: a dense retrieval model.\n"
        "    - **ToolRet**: a dense retrieval model fine-tuned for tool search.\n"
        "    - **BM25**: a sparse retrieval model.\n"
        "- **Indexing Configurations**:\n"
        "    - **v1**: using all columns with priority ordering (e.g., name, description come first).\n"
        "    - **naive**: using agent name and description only.\n"
    )

    def __init__(self, agentbase_path, platforms_path):
        """
        Load the data files and the (cached) retrievers.

        Args:
            agentbase_path: Path to the AgentBase agents CSV; also passed to
                the retrievers.
            platforms_path: Path to the platforms CSV.
        """
        # Selection options and defaults; header_panel overwrites the
        # selected values with the user's widget choices.
        self.retrieval_models = ["bge-large", "toolret", "bm25"]
        self.selected_model = "bge-large"
        self.indexing_configs = ["v1", "naive"]
        self.indexing_config = "v1"

        self.agents_df = load_agentbase_data(agentbase_path)
        self.platforms_df = pd.read_csv(platforms_path)
        self.bm25s, self.bges, self.toolrets = load_retrievers(
            agentbase_path, index_configs=self.indexing_configs
        )

    def header_panel(self):
        """
        Render the title, the search box, the model/index selectors and the help text.

        Stores the selected retrieval model and indexing configuration on
        ``self``. ``retrieve_agents`` reads them from there.
        """
        st.title("AgentBase Retriever")
        st.write("A Large-Scale Agent Collection for Automated Agent Recommendation.")
        # Seed the widget-backed value so retrieval_panel can read it on the
        # very first run.
        if "query" not in st.session_state:
            st.session_state.query = ""

        col1, col2, col3 = st.columns([4, 1, 1])
        # Non-empty labels, visually hidden: Streamlit warns about empty labels
        # for accessibility reasons.
        with col1:
            st.text_input(
                "Search",
                placeholder="Type to search...",
                key="query",
                label_visibility="collapsed",
            )
        with col2:
            self.selected_model = st.selectbox(
                "Retrieval model",
                self.retrieval_models,
                index=0,
                label_visibility="collapsed",
            )
        with col3:
            self.indexing_config = st.selectbox(
                "Indexing configuration",
                self.indexing_configs,
                index=0,
                label_visibility="collapsed",
            )

        _, help_col = st.columns([2, 1])
        with help_col:
            with st.expander("See explanation"):
                st.write(self._EXPLANATION_MD)

    def retrieval_panel(self, top_k: int = 5):
        """
        Render up to ``top_k`` agents for the current query.

        With a non-empty ``st.session_state.query``, the rows come from
        ``retrieve_agents``. Otherwise the first ``top_k`` catalogue rows are
        shown with every score set to 0. The table shows a fixed subset of
        columns; the "Copy CSV" and "Copy JSONL" buttons export every column
        of the same rows.

        Args:
            top_k: Maximum number of rows to show (default 5).
        """
        if st.session_state.query:
            self.filtered_df = self.retrieve_agents(st.session_state.query, top_k)
        else:
            self.filtered_df = self.agents_df.copy()
            self.filtered_df["scores"] = 0.0

        if self.filtered_df.empty:
            st.info("No agents match your search.")
            return

        key_columns = [
            "agent_name",
            "platform_name",
            "agent_description",
            "agent_pricing",
            "base_model",
            "agent_url",
            "scores",
        ]
        # All-zero scores (always the case without a query) carry no
        # information, so hide that column.
        if (self.filtered_df["scores"] == 0).all():
            key_columns.remove("scores")

        top_rows = self.filtered_df.head(top_k)
        display_df = top_rows[key_columns]
        csv_text = top_rows.to_csv(index=False)
        jsonl_text = top_rows.to_json(orient="records", lines=True)

        # --- header row with actions ---
        col1, col2, col3 = st.columns([4, 1, 1])
        with col1:
            # Report the rows actually shown: the retrieved set may be
            # smaller than top_k.
            st.write(f"Showing {len(display_df)} of {len(self.agents_df)} agents")
        with col2:
            copy_button(csv_text, "Copy CSV")
        with col3:
            copy_button(jsonl_text, "Copy JSONL")

        agent_config = {
            "agent_name": st.column_config.TextColumn(
                "agent_name", width="medium"
            ),
            "agent_url": st.column_config.LinkColumn(
                "agent_url", display_text="Visit →"
            ),
            "agent_description": st.column_config.TextColumn(
                "agent_description", width="large"
            ),
            "agent_accessibility": st.column_config.TextColumn(
                "agent_accessibility", width="small"
            ),
            "agent_pricing": st.column_config.TextColumn(
                "agent_pricing", width="medium"
            ),
            "base_model": st.column_config.TextColumn(
                "base_model", width="medium"
            ),
        }
        st.dataframe(
            display_df,
            column_config=agent_config,
            use_container_width=True,
            hide_index=True,
        )

    def retrieve_agents(self, query, top_k=100) -> pd.DataFrame:
        """
        Return the catalogue rows retrieved for ``query``, best score first.

        Uses the retriever for ``self.selected_model`` and
        ``self.indexing_config``. Each retriever's ``retrieve`` result is
        treated as ``(agent_id, score)`` pairs. The returned DataFrame is a
        copy of the matching ``agents_df`` rows plus a float ``scores``
        column; ``self.agents_df`` itself is not modified. When nothing is
        retrieved, an empty DataFrame (still with a ``scores`` column) is
        returned.

        Args:
            query: Free-text search query.
            top_k: Number of hits to request from the retriever (default 100).

        Raises:
            ValueError: If ``self.selected_model`` is not a known model.
        """
        retrievers_by_model = {
            "bm25": self.bm25s,
            "bge-large": self.bges,
            "toolret": self.toolrets,
        }
        try:
            retrievers = retrievers_by_model[self.selected_model]
        except KeyError:
            raise ValueError(f"Selected model must be one of {self.retrieval_models}") from None
        res = retrievers[self.indexing_config].retrieve(query, top_k)

        score_by_id = dict(res)
        # Work on a copy: agents_df may be a cached object shared by reruns.
        filtered_df = self.agents_df[self.agents_df["agent_id"].isin(list(score_by_id))].copy()
        filtered_df["scores"] = filtered_df["agent_id"].map(score_by_id).astype(float)
        return filtered_df.sort_values(by="scores", ascending=False)
def copy_button(text: str, label: str):
    """
    Render a button that copies ``text`` to the clipboard when clicked.

    The button is drawn with ``st.components.v1.html`` (height 52) and shows
    "Copied" for about 0.9 s after a click.

    Args:
        text: Text to place on the clipboard (e.g. CSV or JSON Lines).
        label: Button caption.
    """
    # json.dumps produces a valid JS string literal: quotes, backslashes,
    # newlines and non-ASCII are escaped. html.escape then makes it safe inside
    # the double-quoted onclick attribute; the browser decodes the entities
    # back before running the JS. Embedding the raw text in a JS template
    # literal would break on backticks, "${" and backslashes in the data.
    js_text = html.escape(json.dumps(text), quote=True)
    safe_label = html.escape(label)
    st.components.v1.html(
        f"""
<button
onclick="
navigator.clipboard.writeText({js_text});
const btn = this;
const original = btn.innerText;
btn.innerText = 'Copied';
btn.classList.add('copied');
setTimeout(() => {{
btn.innerText = original;
btn.classList.remove('copied');
}}, 900);
"
style="
width:100%;
padding:10px 12px;
border-radius:8px;
border:1px solid #ddd;
cursor:pointer;
font-size:14px;
background:#fff;
transition:
background-color 0.15s ease,
transform 0.08s ease,
box-shadow 0.08s ease;
"
onmouseover="this.style.background='#f6f7f9'"
onmouseout="this.style.background='#fff'"
onmousedown="
this.style.transform='scale(0.97)';
this.style.boxShadow='inset 0 2px 6px rgba(0,0,0,0.15)';
"
onmouseup="
this.style.transform='scale(1)';
this.style.boxShadow='none';
"
>
{safe_label}
</button>
""",
        height=52,
    )
if __name__ == "__main__":
    # Data files live in ../data relative to this page's directory.
    BASE_DIR = Path(__file__).resolve().parent
    data_dir = BASE_DIR / ".." / "data"
    agentbase_path = data_dir / "agentbase-v1.1.csv"
    platforms_path = data_dir / "platforms.csv"

    agentbaseui = AgentBaseUI(agentbase_path, platforms_path)
    # Order matters: header_panel creates the widgets whose values
    # retrieval_panel reads.
    for render in (agentbaseui.header_panel, agentbaseui.retrieval_panel):
        render()