# app.py — Retrieval Task Leaderboard (Hugging Face Space)
import os
import json
import requests
import glob
from pathlib import Path
import gradio as gr
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.repocard import metadata_load
from apscheduler.schedulers.background import BackgroundScheduler
from tqdm.contrib.concurrent import thread_map
from utils import *
# Configuration for retrieval task leaderboard.
SUBMISSION_FOLDER = "submission"  # directory scanned for *.json submission files
HF_TOKEN = os.environ.get("HF_TOKEN")  # Hugging Face API token (None when unset, e.g. locally)
block = gr.Blocks()  # top-level Gradio container; populated and launched at the bottom of the file
api = HfApi(token=HF_TOKEN)  # hub client used by restart() to restart the Space on a schedule
# Retrieval task metrics configuration.
# metric_key matches the field names required in submission JSON files;
# metric_name is the human-readable label used for leaderboard columns.
retrieval_metrics = [
    {
        "metric_name": "Hit Rate Click@50",
        "metric_key": "hit_rate_click@50",
        "description": "Hit rate for click predictions at top 50"
    },
    {
        "metric_name": "Hit Rate A2C@50",
        "metric_key": "hit_rate_A2C@50",
        "description": "Hit rate for A2C predictions at top 50"
    },
    {
        "metric_name": "Hit Rate Purchase@50",
        "metric_key": "hit_rate_purchase@50",
        "description": "Hit rate for purchase predictions at top 50"
    }
]

# Main leaderboard configuration.
# NOTE(review): neither retrieval_metrics nor leaderboard_config appears to be
# read anywhere else in this file — they look like declarative metadata only;
# confirm before relying on them.
leaderboard_config = {
    "title": "πŸ† Retrieval Task Leaderboard πŸ†",
    "description": "Leaderboard for retrieval task performance",
    "metrics": retrieval_metrics
}
def restart(repo_id="huggingface-projects/Deep-Reinforcement-Learning-Leaderboard"):
    """Restart a Hugging Face Space so the app rebuilds and re-reads submissions.

    Args:
        repo_id: Space to restart. Defaults to the previously hard-coded id
            for backward compatibility.
            NOTE(review): this default names the Deep-RL leaderboard Space,
            not this retrieval-task leaderboard — it looks copied from a
            template. Confirm and point it at this Space's repo id.
    """
    print("RESTART")
    api.restart_space(repo_id=repo_id)
def load_submission_files(folder=None):
    """Load and validate all JSON submission files from a folder.

    Args:
        folder: Directory to scan for ``*.json`` files. Defaults to the
            module-level ``SUBMISSION_FOLDER`` (resolved at call time).

    Returns:
        list[dict]: Parsed submissions containing every required metric
        field. Malformed, incomplete, or unreadable files are skipped
        with a printed warning rather than raising.
    """
    if folder is None:
        folder = SUBMISSION_FOLDER
    # Field names must match the submission JSON schema documented in the UI.
    required_fields = (
        "user_id",
        "model_id",
        "hit_rate_click@50",
        "hit_rate_A2C@50",
        "hit_rate_purchase@50",
    )
    submissions = []
    for file_path in glob.glob(os.path.join(folder, "*.json")):
        try:
            # Explicit encoding: submissions may contain non-ASCII comments.
            with open(file_path, "r", encoding="utf-8") as f:
                submission_data = json.load(f)
        except (json.JSONDecodeError, OSError) as e:
            # OSError covers FileNotFoundError/PermissionError races too.
            print(f"Error reading {file_path}: {e}")
            continue
        if all(field in submission_data for field in required_fields):
            submissions.append(submission_data)
        else:
            print(f"Warning: Invalid submission format in {file_path}")
    return submissions
def parse_submission_data(submission):
    """Convert one raw submission dict into a formatted leaderboard row.

    Missing identity fields fall back to "Unknown", missing metrics to 0.
    Returns the row dict, or None when a metric value cannot be coerced
    to float.
    """
    try:
        metric_values = {
            "Hit Rate Click@50": float(submission.get("hit_rate_click@50", 0)),
            "Hit Rate A2C@50": float(submission.get("hit_rate_A2C@50", 0)),
            "Hit Rate Purchase@50": float(submission.get("hit_rate_purchase@50", 0)),
        }
    except (ValueError, TypeError) as e:
        print(f"Error parsing submission data: {e}")
        return None
    row = {
        "User": submission.get("user_id", "Unknown"),
        "Model": submission.get("model_id", "Unknown"),
        "Dataset": submission.get("dataset_id", "Unknown"),
    }
    row.update(metric_values)
    row["Comment"] = submission.get("comment", "")
    return row
def update_leaderboard_from_submissions():
    """Build the leaderboard DataFrame from the JSON submission files.

    Returns:
        pd.DataFrame: One row per valid submission with a 1-based "Ranking"
        column, sorted by "Hit Rate Click@50" descending. The empty frame
        carries the same columns the UI table declares.
    """
    rows = [parse_submission_data(s) for s in load_submission_files()]
    rows = [r for r in rows if r]  # drop submissions that failed parsing
    if not rows:
        # Bug fix: include "Ranking" so the empty frame matches the headers
        # declared on the gr.Dataframe component below.
        return pd.DataFrame(columns=[
            "Ranking", "User", "Model", "Dataset", "Hit Rate Click@50",
            "Hit Rate A2C@50", "Hit Rate Purchase@50", "Comment",
        ])
    df = pd.DataFrame(rows)
    # Default ordering: best click hit-rate first.
    df = df.sort_values(by="Hit Rate Click@50", ascending=False)
    df.reset_index(drop=True, inplace=True)
    df.insert(0, "Ranking", range(1, len(df) + 1))
    return df
def rank_dataframe(dataframe):
    """Sort by Results/User/Model (all descending) and (re)number rows 1..n.

    NOTE(review): nothing in this app produces a "Results" column and this
    helper is not called anywhere in the visible file — it looks like a
    leftover from a previous leaderboard; confirm before removing.

    Args:
        dataframe: pd.DataFrame with "Results", "User", and "Model" columns.

    Returns:
        pd.DataFrame: Sorted copy with a leading 1-based "Ranking" column
        (overwritten in place if it already exists).
    """
    ranked = dataframe.sort_values(by=["Results", "User", "Model"], ascending=False)
    rankings = list(range(1, len(ranked) + 1))
    if "Ranking" not in ranked.columns:
        ranked.insert(0, "Ranking", rankings)
    else:
        ranked["Ranking"] = rankings
    return ranked
def get_leaderboard_data():
    """Return the current leaderboard DataFrame, rebuilt from submissions."""
    current = update_leaderboard_from_submissions()
    return current
def refresh_leaderboard():
    """Rebuild and return the leaderboard data for the UI refresh button."""
    print("πŸ”„ Refreshing leaderboard...")
    latest = get_leaderboard_data()
    return latest
# --- Gradio UI definition -------------------------------------------------
with block:
    # Static instructions rendered above the leaderboard table.
    gr.Markdown("""
# πŸ† Retrieval Task Leaderboard πŸ†
This leaderboard tracks the performance of different models on retrieval tasks.
### How to Submit
Submit your results as a JSON file in the `submission` folder via pull request.
### Required JSON Format
```json
{
"user_id": "your_username",
"model_id": "your_model_name",
"hit_rate_click@50": "0.75",
"hit_rate_A2C@50": "0.68",
"hit_rate_purchase@50": "0.82",
"dataset_id": "your_dataset",
"comment": "Optional comment about your submission"
}
```
### How to Update After PR
**Currently, PR detection is NOT automated.** After a PR is merged:
Wait until the APP is rebuilt.
### Rankings
Currently ranked by "Hit Rate Click@50" (you can modify the sorting in the code)
""")
    # Manual refresh: re-reads the submission folder on click.
    refresh_button = gr.Button("πŸ”„ Refresh Leaderboard")
    # Table is populated once at import time; the refresh button replaces its value.
    leaderboard_df = gr.Dataframe(
        value=get_leaderboard_data(),
        headers=["Ranking", "User", "Model", "Dataset", "Hit Rate Click@50", "Hit Rate A2C@50", "Hit Rate Purchase@50", "Comment"],
        label="Current Leaderboard"
    )
    # Wire the button to the refresh handler; output replaces the dataframe.
    refresh_button.click(refresh_leaderboard, outputs=leaderboard_df)

# --- Startup: logging, periodic self-restart, launch ----------------------
print("πŸš€ Starting Retrieval Task Leaderboard...")
# Background scheduler restarts the Space every 6 hours (21600 s) so newly
# merged submission PRs are eventually picked up without manual action.
scheduler = BackgroundScheduler()
scheduler.add_job(restart, 'interval', seconds=21600)  # Restart every 6 hours
scheduler.start()
print("βœ… System initialized successfully!")
print("πŸ“Š Leaderboard accessible at: http://127.0.0.1:7860")
print("⚠️ PR detection is NOT automated - restart manually after PR merges")
print("πŸ”„ Use the refresh button in the UI to update the leaderboard")
block.launch()