|
|
import gradio as gr |
|
|
import gymnasium as gym |
|
|
from stable_baselines3 import PPO |
|
|
from huggingface_hub import HfApi, snapshot_download, login |
|
|
import pandas as pd |
|
|
import os |
|
|
import shutil |
|
|
import time |
|
|
|
|
|
|
|
|
# Hub token used for dataset reads/writes; expected as a Space secret.
# NOTE(review): may be None if the env var is unset — login() would then
# fail or fall back to cached credentials; confirm the Space sets HF_TOKEN.
HF_TOKEN = os.environ.get("HF_TOKEN")


# Dataset repo holding incoming evaluation requests (requests.csv).
REQUESTS_DATASET = "gberseth/rl-leaderboard-requests"


# Dataset repo holding published scores (results.csv).
RESULTS_DATASET = "gberseth/rl-leaderboard-results"


# Gymnasium environment every submission is evaluated on.
ENV_NAME = "CartPole-v1"


# Number of episodes averaged per evaluation.
EVAL_EPISODES = 10








# Authenticate this process with the Hub at import time (side effect).
login(token=HF_TOKEN)


# Shared client used below to upload updated CSVs back to the datasets.
api = HfApi()
|
|
|
|
|
def evaluate_policy(model_id):
    """
    Download a PPO model from the HF Hub and evaluate it on ENV_NAME.

    Args:
        model_id: Hub repo id of the submitted model (e.g. "user/ppo-cartpole").

    Returns:
        Tuple of (mean_reward, "Success") on success, or (None, error_message)
        if the download, model load, or rollout fails.
    """
    print(f"Starting evaluation for: {model_id}")
    try:
        # Only pull .zip checkpoints; skips README / model-card files.
        repo_path = snapshot_download(repo_id=model_id, allow_patterns=["*.zip"])

        # Find the first .zip in the snapshot.
        # Bug fix: the original `break` only exited the per-directory file
        # loop, so os.walk kept descending and a later .zip could silently
        # overwrite the first match. Break out of both loops now.
        model_file = None
        for root, dirs, files in os.walk(repo_path):
            for file in files:
                if file.endswith(".zip"):
                    model_file = os.path.join(root, file)
                    break
            if model_file:
                break

        if not model_file:
            return None, "Error: No .zip model file found in repo."

        model = PPO.load(model_file)

        env = gym.make(ENV_NAME)
        try:
            total_rewards = []
            for _ in range(EVAL_EPISODES):
                obs, _ = env.reset()
                done = False
                truncated = False
                episode_reward = 0.0
                # Gymnasium API: step returns (obs, reward, terminated,
                # truncated, info); run until either flag is set.
                while not (done or truncated):
                    action, _ = model.predict(obs, deterministic=True)
                    obs, reward, done, truncated, _ = env.step(action)
                    # Coerce to float: env.step may return a numpy scalar.
                    episode_reward += float(reward)
                total_rewards.append(episode_reward)

            mean_reward = sum(total_rewards) / len(total_rewards)
        finally:
            # Bug fix: close the env even if an episode raises, so repeated
            # evaluations don't leak environment resources.
            env.close()

        return mean_reward, "Success"

    except Exception as e:
        # Top-level boundary: report the failure upward instead of crashing
        # the evaluator loop.
        print(f"Evaluation failed: {e}")
        return None, str(e)
|
|
|
|
|
def run_evaluation_loop():
    """
    Process one pending submission end-to-end.

    Reads the requests dataset, evaluates the first 'Pending' model, appends
    its score to the results dataset, and writes the updated request status
    ('Done' or 'Failed') back to the Hub.

    Returns:
        A human-readable status string shown in the Gradio log box.
    """
    print("Checking for new submissions...")

    try:
        requests_df = pd.read_csv(f"hf://datasets/{REQUESTS_DATASET}/requests.csv")
    except Exception:
        # Dataset missing, empty, or unreadable — nothing to do this run.
        print("Requests dataset not found or empty.")
        return "No requests found."

    pending_rows = requests_df[requests_df["status"] == "Pending"]

    if pending_rows.empty:
        return "No pending submissions."

    # Process only the first pending row per invocation to keep runs short;
    # the button/loop can be triggered again for the rest.
    row_index = pending_rows.index[0]
    model_id = pending_rows.loc[row_index, "model_id"]

    print(f"Evaluating {model_id}...")

    score, status_msg = evaluate_policy(model_id)
    if score is None:
        # Bug fix: the failure reason was silently discarded before —
        # surface it in the Space logs for debugging.
        print(f"Evaluation of {model_id} failed: {status_msg}")

    requests_df.loc[row_index, "status"] = "Done" if score is not None else "Failed"

    if score is not None:
        new_result = {
            "model_id": model_id,
            "mean_reward": score,
            "status": "Success",
        }

        try:
            results_df = pd.read_csv(f"hf://datasets/{RESULTS_DATASET}/results.csv")
        except Exception:
            # Bug fix: was a bare `except:`, which would also swallow
            # KeyboardInterrupt/SystemExit. First run: start an empty table.
            results_df = pd.DataFrame(columns=["model_id", "mean_reward", "status"])

        results_df = pd.concat([results_df, pd.DataFrame([new_result])], ignore_index=True)

        # Write locally, then push the CSV back to the results dataset.
        results_df.to_csv("results.csv", index=False)
        api.upload_file(
            path_or_fileobj="results.csv",
            path_in_repo="results.csv",
            repo_id=RESULTS_DATASET,
            repo_type="dataset",
        )

    # Persist the updated request status regardless of success/failure so
    # a failed submission is not retried forever.
    requests_df.to_csv("requests.csv", index=False)
    api.upload_file(
        path_or_fileobj="requests.csv",
        path_in_repo="requests.csv",
        repo_id=REQUESTS_DATASET,
        repo_type="dataset",
    )

    return f"Processed {model_id}: Score {score}"
|
|
|
|
|
|
|
|
# Minimal Gradio UI: this Space is a backend worker, so the interface is
# just a manual trigger for the evaluation pass plus a log readout.
with gr.Blocks() as demo:
    gr.Markdown("# RL Evaluation Backend")
    gr.Markdown("This space runs in the background to evaluate new submissions.")

    # Manually kick off one evaluation pass (processes one pending request).
    eval_btn = gr.Button("Run Evaluator Now")
    # Displays the status string returned by run_evaluation_loop.
    output = gr.Textbox(label="Logs")

    eval_btn.click(fn=run_evaluation_loop, outputs=output)


# queue() serializes click events so two evaluation passes don't race on
# the shared requests.csv/results.csv files; launch() blocks serving the app.
demo.queue().launch()
|
|
|