Spaces:
Running
Running
| """ | |
| AgentBase Retrieval UI (limited interface). | |
| Author: Arastun Mammadli | |
| Date: [Current Date] | |
| """ | |
| from typing import List, Tuple | |
| from pathlib import Path | |
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import html | |
| from retrieval.models.bm25 import BM25Retriever | |
| from retrieval.models.sentence_bert import DenseRetriever | |
| from retrieval.utils import load_queries | |
# Page-level Streamlit settings: browser tab title and wide layout.
st.set_page_config(
    page_title="AgentBase",
    layout="wide",
)
def load_retrievers(agentbase_path: str, index_configs: List[str]) -> Tuple[dict, dict, dict]:
    """
    Build one retriever of each kind for every index configuration.

    Args:
        agentbase_path: Path to the AgentBase CSV, forwarded to each retriever.
        index_configs: Index configuration names; each becomes a key in every
            returned dict and is forwarded as ``index_config``.

    Returns:
        A ``(bm25s, bges, toolrets)`` tuple of dicts keyed by index
        configuration:
        - ``bm25s``: ``BM25Retriever`` instances,
        - ``bges``: ``DenseRetriever`` using "BAAI/bge-large-en-v1.5",
        - ``toolrets``: ``DenseRetriever`` using
          "mangopy/ToolRet-trained-bge-large-en-v1.5".
        All three dicts are empty when ``index_configs`` is empty.
    """
    bm25s = {}
    bges = {}
    toolrets = {}
    for idx_config in index_configs:
        bm25s[idx_config] = BM25Retriever(agentbase_path, index_config=idx_config)
        bges[idx_config] = DenseRetriever(
            "BAAI/bge-large-en-v1.5", agentbase_path, index_config=idx_config
        )
        toolrets[idx_config] = DenseRetriever(
            "mangopy/ToolRet-trained-bge-large-en-v1.5",
            agentbase_path,
            index_config=idx_config,
        )
    return bm25s, bges, toolrets
def load_agentbase_data(agentbase_path: str) -> pd.DataFrame:
    """Read the AgentBase CSV at ``agentbase_path`` into a DataFrame."""
    agents = pd.read_csv(agentbase_path)
    return agents
class AgentBaseUI:
    """
    AgentBase Streamlit-based UI Components.

    Holds the agent and platform tables plus one retriever per
    (model, index configuration) pair, and renders the search header and the
    results table.
    """

    def __init__(self, agentbase_path, platforms_path):
        """
        Load the data tables and retrievers and set the default selections.

        Args:
            agentbase_path: Path to the AgentBase CSV.
            platforms_path: Path to the platforms CSV.
        """
        self.agents_df = load_agentbase_data(agentbase_path)
        self.platforms_df = pd.read_csv(platforms_path)

        # selection options and defaults
        self.retrieval_models = ["bge-large", "toolret", "bm25"]
        self.selected_model = "bge-large"
        self.indexing_configs = ["v1", "naive"]
        self.indexing_config = "v1"

        # Build retrievers from the same list the UI offers, so the two
        # cannot drift apart.
        self.bm25s, self.bges, self.toolrets = load_retrievers(
            agentbase_path, index_configs=self.indexing_configs
        )

    def header_panel(self):
        """
        Render the title, the search box, and the model / index selectors.

        Stores the typed query in ``st.session_state.query`` and updates
        ``self.selected_model`` and ``self.indexing_config`` from the
        select boxes.
        """
        st.title("AgentBase Retriever")
        st.write("A Large-Scale Agent Collection for Automated Agent Recommendation.")

        if "query" not in st.session_state:
            st.session_state.query = ""

        col1, col2, col3 = st.columns([4, 1, 1])
        # Widgets get real labels (hidden via label_visibility) instead of
        # empty strings, which Streamlit discourages for accessibility.
        with col1:
            st.text_input(
                "Search query",
                placeholder="Type to search...",
                key="query",
                label_visibility="collapsed",
            )
        with col2:
            self.selected_model = st.selectbox(
                "Retrieval model",
                self.retrieval_models,
                index=0,
                label_visibility="collapsed",
            )
        with col3:
            self.indexing_config = st.selectbox(
                "Indexing configuration",
                self.indexing_configs,
                index=0,
                label_visibility="collapsed",
            )

        _, col2 = st.columns([2, 1])
        with col2:
            with st.expander("See explanation"):
                st.write('''
                - **Retrieval Models**:
                    - **BGE-Large**: a dense retrieval model.
                    - **ToolRet**: a dense retrieval model fine-tuned for tool search.
                    - **BM25**: a sparse retrieval model.
                - **Indexing Configurations**:
                    - **v1**: using all columns with priority ordering (e.g., name, description come first).
                    - **naive**: using agent name and description only.
                ''')

    def retrieval_panel(self, top_k: int = 5):
        """
        Render the results table and the copy-to-clipboard buttons.

        With a non-empty query the agents come from ``retrieve_agents``;
        otherwise the full agent table is used with all scores set to 0.0.
        Sets ``self.filtered_df`` as a side effect.

        Args:
            top_k: Maximum number of rows to retrieve and display.
        """
        if st.session_state.query:
            self.filtered_df = self.retrieve_agents(st.session_state.query, top_k)
        else:
            self.filtered_df = self.agents_df.copy()
            self.filtered_df["scores"] = 0.0

        if self.filtered_df.empty:
            st.info("No agents match your search.")
            return

        key_columns = [
            "agent_name",
            "platform_name",
            "agent_description",
            "agent_pricing",
            "base_model",
            "agent_url",
            "scores",
        ]
        # An all-zero score column carries no information (no query given).
        if (self.filtered_df["scores"] == 0).all():
            key_columns.remove("scores")

        top_rows = self.filtered_df.head(top_k)
        display_df = top_rows[key_columns]
        csv_text = top_rows.to_csv(index=False)
        jsonl_text = top_rows.to_json(orient="records", lines=True)

        # --- header row with actions ---
        col1, col2, col3 = st.columns([4, 1, 1])
        with col1:
            # Report the number of rows actually shown, which can be
            # smaller than top_k.
            st.write(f"Showing {len(display_df)} of {len(self.agents_df)} agents")
        with col2:
            copy_button(csv_text, "Copy CSV")
        with col3:
            copy_button(jsonl_text, "Copy JSONL")

        agent_config = {
            "agent_name": st.column_config.TextColumn(
                "agent_name", width="medium"
            ),
            "agent_url": st.column_config.LinkColumn(
                "agent_url", display_text="Visit →"
            ),
            "agent_description": st.column_config.TextColumn(
                "agent_description", width="large"
            ),
            "agent_accessibility": st.column_config.TextColumn(
                "agent_accessibility", width="small"
            ),
            "agent_pricing": st.column_config.TextColumn(
                "agent_pricing", width="medium"
            ),
            "base_model": st.column_config.TextColumn(
                "base_model", width="medium"
            ),
        }
        st.dataframe(
            display_df,
            column_config=agent_config,
            use_container_width=True,
            hide_index=True,
        )

    def retrieve_agents(self, query, top_k=100) -> pd.DataFrame:
        """
        Return the agents matched by the selected retriever, best first.

        The retriever is chosen by ``self.selected_model`` and
        ``self.indexing_config``. Its result is treated as an iterable of
        ``(agent_id, score)`` pairs. ``self.agents_df`` is not modified.

        Args:
            query: Search text passed to the retriever.
            top_k: Maximum number of results requested (default 100).

        Returns:
            A copy of the matching rows of ``self.agents_df`` with a float
            ``scores`` column, sorted by score in descending order. The frame
            is empty (but still has ``scores``) when nothing is retrieved.

        Raises:
            ValueError: If ``self.selected_model`` is not a known model.
        """
        if self.selected_model == 'bm25':
            res = self.bm25s[self.indexing_config].retrieve(query, top_k)
        elif self.selected_model == 'bge-large':
            res = self.bges[self.indexing_config].retrieve(query, top_k)
        elif self.selected_model == 'toolret':
            res = self.toolrets[self.indexing_config].retrieve(query, top_k)
        else:
            raise ValueError(f"Selected model must be one of {self.retrieval_models}")

        # Build the id -> score lookup once; an empty result gives an empty
        # dict and therefore an empty frame rather than an unpacking error.
        score_by_id = dict(res)
        is_match = self.agents_df["agent_id"].isin(list(score_by_id))
        # Work on a copy so the shared agents table is never mutated.
        matches = self.agents_df.loc[is_match].copy()
        matches["scores"] = matches["agent_id"].map(score_by_id).astype(float)
        return matches.sort_values(by="scores", ascending=False)
def copy_button(text: str, label: str):
    """
    Render an HTML button that copies ``text`` to the clipboard when clicked.

    The button briefly shows "Copied" after a click and then restores its
    original caption.

    Args:
        text: The text written to the clipboard.
        label: The caption shown on the button.
    """
    import json  # local import so this block needs no file-level change

    # Encode in two layers, matching how the browser decodes the value:
    # 1. json.dumps produces a double-quoted JavaScript string literal, so
    #    backticks, backslashes, "${" and newlines in the data stay inert.
    # 2. html.escape makes that literal safe inside the double-quoted
    #    onclick attribute.
    js_literal = html.escape(json.dumps(text), quote=True)
    safe_label = html.escape(label)
    st.components.v1.html(
        f"""
        <button
            onclick="
                navigator.clipboard.writeText({js_literal});
                const btn = this;
                const original = btn.innerText;
                btn.innerText = 'Copied';
                btn.classList.add('copied');
                setTimeout(() => {{
                    btn.innerText = original;
                    btn.classList.remove('copied');
                }}, 900);
            "
            style="
                width:100%;
                padding:10px 12px;
                border-radius:8px;
                border:1px solid #ddd;
                cursor:pointer;
                font-size:14px;
                background:#fff;
                transition:
                    background-color 0.15s ease,
                    transform 0.08s ease,
                    box-shadow 0.08s ease;
            "
            onmouseover="this.style.background='#f6f7f9'"
            onmouseout="this.style.background='#fff'"
            onmousedown="
                this.style.transform='scale(0.97)';
                this.style.boxShadow='inset 0 2px 6px rgba(0,0,0,0.15)';
            "
            onmouseup="
                this.style.transform='scale(1)';
                this.style.boxShadow='none';
            "
        >
            {safe_label}
        </button>
        """,
        height=52,
    )
if __name__ == "__main__":
    # Data files are located relative to this script's directory.
    BASE_DIR = Path(__file__).resolve().parent
    agentbase_path = BASE_DIR / "../data/agentbase-v1.1.csv"
    platforms_path = BASE_DIR / "../data/platforms.csv"

    # Build the UI once, then draw the search header followed by the results.
    ui = AgentBaseUI(agentbase_path, platforms_path)
    for render_panel in (ui.header_panel, ui.retrieval_panel):
        render_panel()