nathanael-fijalkow committed on
Commit
88fbdea
·
1 Parent(s): 8a7719b

uniformize local and server-side evaluation

Browse files
Files changed (2) hide show
  1. app.py +11 -108
  2. src/evaluate.py +15 -1
app.py CHANGED
@@ -13,8 +13,6 @@ less than 1M parameters! This is approximately the number of neurons of a honey
13
  Leaderboard data is stored in a private HuggingFace dataset for persistence.
14
  """
15
 
16
- import hashlib
17
- import hmac
18
  import io
19
  import os
20
  import sys
@@ -24,17 +22,16 @@ from typing import Optional
24
 
25
  import gradio as gr
26
  import pandas as pd
27
- from fastapi import FastAPI, Request, BackgroundTasks
28
-
29
- # Create FastAPI app for webhook
30
- fastapi_app = FastAPI()
31
 
32
  # Configuration
33
  ORGANIZATION = os.environ.get("HF_ORGANIZATION", "LLM-course")
34
  LEADERBOARD_DATASET = os.environ.get("LEADERBOARD_DATASET", f"{ORGANIZATION}/chess-challenge-leaderboard")
35
  LEADERBOARD_FILENAME = "leaderboard.csv"
36
  HF_TOKEN = os.environ.get("HF_TOKEN") # Required for private dataset access
37
- WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET", "") # For webhook verification
 
 
 
38
 
39
  STOCKFISH_LEVELS = {
40
  "Beginner (Level 0)": 0,
@@ -342,7 +339,6 @@ def evaluate_legal_moves(
342
  progress: gr.Progress = gr.Progress(),
343
  ) -> str:
344
  """Evaluate a model's legal move generation."""
345
- n_positions = 500 # Fixed number of positions
346
  try:
347
  import sys
348
  sys.path.insert(0, str(Path(__file__).parent))
@@ -359,8 +355,12 @@ def evaluate_legal_moves(
359
  stockfish_level=1, # Not used for legal move eval
360
  )
361
 
362
- progress(0.2, desc=f"Testing {n_positions} positions...")
363
- results = evaluator.evaluate_legal_moves(n_positions=n_positions, verbose=False)
 
 
 
 
364
 
365
  # Extract user_id from model's README (submitted by field)
366
  user_id = get_model_submitter(model_id)
@@ -756,102 +756,5 @@ with gr.Blocks(
756
  refresh_btn.click(refresh_leaderboard, outputs=leaderboard_html)
757
 
758
 
759
- # =============================================================================
760
- # WEBHOOK HANDLERS FOR AUTOMATIC EVALUATION
761
- # =============================================================================
762
-
763
- def verify_webhook_secret(secret: str) -> bool:
764
- """Verify the webhook secret from Hugging Face."""
765
- if not WEBHOOK_SECRET:
766
- print("WEBHOOK_SECRET not set - skipping verification")
767
- return True
768
- return hmac.compare_digest(WEBHOOK_SECRET, secret)
769
-
770
-
771
- def run_auto_evaluation(model_id: str):
772
- """Run model evaluation in background after webhook trigger."""
773
- try:
774
- print(f"🚀 Auto-evaluating new model: {model_id}")
775
-
776
- # Import evaluation functions
777
- sys.path.insert(0, str(Path(__file__).parent))
778
- from src.evaluate import ChessEvaluator, load_model_from_hub
779
-
780
- # Load model
781
- model, tokenizer = load_model_from_hub(model_id)
782
-
783
- # Run legal moves evaluation (quick first pass)
784
- evaluator = ChessEvaluator(
785
- model=model,
786
- tokenizer=tokenizer,
787
- stockfish_level=1,
788
- )
789
- results = evaluator.evaluate_legal_moves(n_positions=100, verbose=True)
790
-
791
- # Update leaderboard
792
- leaderboard = load_leaderboard()
793
- entry = next((e for e in leaderboard if e["model_id"] == model_id), None)
794
- if entry is None:
795
- entry = {"model_id": model_id}
796
- leaderboard.append(entry)
797
-
798
- entry.update({
799
- "legal_rate": results.get("legal_rate_with_retry", 0),
800
- "legal_rate_first_try": results.get("legal_rate_first_try", 0),
801
- "last_updated": datetime.now().strftime("%Y-%m-%d %H:%M"),
802
- })
803
-
804
- save_leaderboard(leaderboard)
805
- print(f"Auto-evaluation complete for {model_id}: legal_rate={results.get('legal_rate_with_retry', 0):.1%}")
806
-
807
- except Exception as e:
808
- print(f"Auto-evaluation failed for {model_id}: {e}")
809
- import traceback
810
- traceback.print_exc()
811
-
812
-
813
- @fastapi_app.post("/webhook")
814
- async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
815
- """Handle incoming webhooks from Hugging Face."""
816
- secret = request.headers.get("X-Webhook-Secret", "")
817
-
818
- # Verify secret
819
- if not verify_webhook_secret(secret):
820
- print("Webhook secret verification failed")
821
- return {"error": "Invalid secret"}, 403
822
-
823
- data = await request.json()
824
- event = data.get("event", {})
825
- event_type = event.get("action")
826
- repo = data.get("repo", {})
827
- repo_type = repo.get("type")
828
- repo_name = repo.get("name")
829
-
830
- print(f"📥 Webhook received: {event_type} for {repo_type}/{repo_name}")
831
-
832
- # Only process model creation/updates in our organization
833
- if repo_type == "model" and repo_name and repo_name.startswith(f"{ORGANIZATION}/"):
834
- if event_type in ["create", "update"]:
835
- # Check if it's a chess model
836
- if "chess" in repo_name.lower():
837
- print(f"Queuing evaluation for chess model: {repo_name}")
838
- background_tasks.add_task(run_auto_evaluation, repo_name)
839
- return {"status": "evaluation_queued", "model": repo_name}
840
- else:
841
- print(f"⏭️ Skipping non-chess model: {repo_name}")
842
-
843
- return {"status": "ignored"}
844
-
845
-
846
- @fastapi_app.get("/health")
847
- async def health_check():
848
- """Health check endpoint."""
849
- return {"status": "healthy", "organization": ORGANIZATION}
850
-
851
-
852
- # Mount Gradio app to FastAPI
853
- fastapi_app = gr.mount_gradio_app(fastapi_app, demo, path="/")
854
-
855
  if __name__ == "__main__":
856
- import uvicorn
857
- uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)
 
13
  Leaderboard data is stored in a private HuggingFace dataset for persistence.
14
  """
15
 
 
 
16
  import io
17
  import os
18
  import sys
 
22
 
23
  import gradio as gr
24
  import pandas as pd
 
 
 
 
25
 
26
  # Configuration
27
  ORGANIZATION = os.environ.get("HF_ORGANIZATION", "LLM-course")
28
  LEADERBOARD_DATASET = os.environ.get("LEADERBOARD_DATASET", f"{ORGANIZATION}/chess-challenge-leaderboard")
29
  LEADERBOARD_FILENAME = "leaderboard.csv"
30
  HF_TOKEN = os.environ.get("HF_TOKEN") # Required for private dataset access
31
+
32
+ # Evaluation settings
33
+ EVAL_SEED = 42
34
+ EVAL_N_POSITIONS = 500
35
 
36
  STOCKFISH_LEVELS = {
37
  "Beginner (Level 0)": 0,
 
339
  progress: gr.Progress = gr.Progress(),
340
  ) -> str:
341
  """Evaluate a model's legal move generation."""
 
342
  try:
343
  import sys
344
  sys.path.insert(0, str(Path(__file__).parent))
 
355
  stockfish_level=1, # Not used for legal move eval
356
  )
357
 
358
+ progress(0.2, desc=f"Testing {EVAL_N_POSITIONS} positions...")
359
+ results = evaluator.evaluate_legal_moves(
360
+ n_positions=EVAL_N_POSITIONS,
361
+ verbose=False,
362
+ seed=EVAL_SEED,
363
+ )
364
 
365
  # Extract user_id from model's README (submitted by field)
366
  user_id = get_model_submitter(model_id)
 
756
  refresh_btn.click(refresh_leaderboard, outputs=leaderboard_html)
757
 
758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
759
  if __name__ == "__main__":
760
+ demo.launch(server_name="0.0.0.0", server_port=7860)
 
src/evaluate.py CHANGED
@@ -578,6 +578,7 @@ class ChessEvaluator:
578
  n_positions: int = 1000,
579
  temperature: float = 0.7,
580
  verbose: bool = True,
 
581
  ) -> dict:
582
  """
583
  Evaluate the model's ability to generate legal moves.
@@ -589,10 +590,15 @@ class ChessEvaluator:
589
  n_positions: Number of positions to test.
590
  temperature: Sampling temperature.
591
  verbose: Whether to print progress.
 
592
 
593
  Returns:
594
  Dictionary with legal move statistics.
595
  """
 
 
 
 
596
  results = {
597
  "total_positions": 0,
598
  "legal_first_try": 0,
@@ -800,6 +806,10 @@ def main():
800
  "--n_positions", type=int, default=500,
801
  help="Number of positions for legal move evaluation"
802
  )
 
 
 
 
803
  parser.add_argument(
804
  "--n_games", type=int, default=100,
805
  help="Number of games to play for win rate evaluation"
@@ -828,7 +838,10 @@ def main():
828
  from src.model import ChessConfig, ChessForCausalLM
829
 
830
  tokenizer = ChessTokenizer.from_pretrained(args.model_path)
831
- model = AutoModelForCausalLM.from_pretrained(args.model_path)
 
 
 
832
  else:
833
  # Assume Hugging Face model ID (or invalid path)
834
  if args.model_path.startswith(".") or args.model_path.startswith("/"):
@@ -858,6 +871,7 @@ def main():
858
  n_positions=args.n_positions,
859
  temperature=args.temperature,
860
  verbose=True,
 
861
  )
862
 
863
  print("\n" + "-" * 40)
 
578
  n_positions: int = 1000,
579
  temperature: float = 0.7,
580
  verbose: bool = True,
581
+ seed: int = 42,
582
  ) -> dict:
583
  """
584
  Evaluate the model's ability to generate legal moves.
 
590
  n_positions: Number of positions to test.
591
  temperature: Sampling temperature.
592
  verbose: Whether to print progress.
593
+ seed: Random seed for reproducibility.
594
 
595
  Returns:
596
  Dictionary with legal move statistics.
597
  """
598
+ # Set random seed for reproducibility
599
+ random.seed(seed)
600
+ torch.manual_seed(seed)
601
+
602
  results = {
603
  "total_positions": 0,
604
  "legal_first_try": 0,
 
806
  "--n_positions", type=int, default=500,
807
  help="Number of positions for legal move evaluation"
808
  )
809
+ parser.add_argument(
810
+ "--seed", type=int, default=42,
811
+ help="Random seed for reproducibility"
812
+ )
813
  parser.add_argument(
814
  "--n_games", type=int, default=100,
815
  help="Number of games to play for win rate evaluation"
 
838
  from src.model import ChessConfig, ChessForCausalLM
839
 
840
  tokenizer = ChessTokenizer.from_pretrained(args.model_path)
841
+ model = AutoModelForCausalLM.from_pretrained(
842
+ args.model_path,
843
+ device_map="auto",
844
+ )
845
  else:
846
  # Assume Hugging Face model ID (or invalid path)
847
  if args.model_path.startswith(".") or args.model_path.startswith("/"):
 
871
  n_positions=args.n_positions,
872
  temperature=args.temperature,
873
  verbose=True,
874
+ seed=args.seed,
875
  )
876
 
877
  print("\n" + "-" * 40)