nathanael-fijalkow committed on
Commit
19d5912
·
1 Parent(s): ea5cabf

another fix?

Browse files
Files changed (2) hide show
  1. app.py +11 -11
  2. src/evaluate.py +21 -10
app.py CHANGED
@@ -281,7 +281,7 @@ def play_move(
281
  render_board_svg(current_fen if current_fen != "startpos" else None),
282
  current_fen,
283
  move_history,
284
- f"⚠️ Model generated illegal move: {move_token}",
285
  )
286
 
287
  except Exception as e:
@@ -289,7 +289,7 @@ def play_move(
289
  render_board_svg(),
290
  "startpos",
291
  "",
292
- f"Error: {str(e)}",
293
  )
294
 
295
 
@@ -543,7 +543,7 @@ with gr.Blocks(
543
  with gr.TabItem("🏆 Leaderboard"):
544
  gr.Markdown("### Current Rankings")
545
  leaderboard_html = gr.HTML(value=format_leaderboard_html(load_leaderboard()))
546
- refresh_btn = gr.Button("🔄 Refresh Leaderboard")
547
  refresh_btn.click(refresh_leaderboard, outputs=leaderboard_html)
548
 
549
  # Interactive Demo Tab
@@ -566,8 +566,8 @@ with gr.Blocks(
566
  )
567
 
568
  with gr.Row():
569
- play_btn = gr.Button("▶️ Model Move", variant="primary")
570
- reset_btn = gr.Button("🔄 Reset")
571
 
572
  status_text = gr.Textbox(label="Status", interactive=False)
573
 
@@ -618,7 +618,7 @@ with gr.Blocks(
618
  label="Number of Positions",
619
  )
620
 
621
- legal_btn = gr.Button("Run Legal Move Evaluation", variant="primary")
622
  legal_results = gr.Markdown()
623
 
624
  legal_btn.click(
@@ -675,7 +675,7 @@ with gr.Blocks(
675
  def verify_webhook_secret(secret: str) -> bool:
676
  """Verify the webhook secret from Hugging Face."""
677
  if not WEBHOOK_SECRET:
678
- print("⚠️ WEBHOOK_SECRET not set - skipping verification")
679
  return True
680
  return hmac.compare_digest(WEBHOOK_SECRET, secret)
681
 
@@ -714,10 +714,10 @@ def run_auto_evaluation(model_id: str):
714
  })
715
 
716
  save_leaderboard(leaderboard)
717
- print(f"Auto-evaluation complete for {model_id}: legal_rate={results.get('legal_rate_with_retry', 0):.1%}")
718
 
719
  except Exception as e:
720
- print(f"Auto-evaluation failed for {model_id}: {e}")
721
  import traceback
722
  traceback.print_exc()
723
 
@@ -729,7 +729,7 @@ async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
729
 
730
  # Verify secret
731
  if not verify_webhook_secret(secret):
732
- print("Webhook secret verification failed")
733
  return {"error": "Invalid secret"}, 403
734
 
735
  data = await request.json()
@@ -746,7 +746,7 @@ async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
746
  if event_type in ["create", "update"]:
747
  # Check if it's a chess model
748
  if "chess" in repo_name.lower():
749
- print(f"🎯 Queuing evaluation for chess model: {repo_name}")
750
  background_tasks.add_task(run_auto_evaluation, repo_name)
751
  return {"status": "evaluation_queued", "model": repo_name}
752
  else:
 
281
  render_board_svg(current_fen if current_fen != "startpos" else None),
282
  current_fen,
283
  move_history,
284
+ f"Model generated illegal move: {move_token}",
285
  )
286
 
287
  except Exception as e:
 
289
  render_board_svg(),
290
  "startpos",
291
  "",
292
+ f"Error: {str(e)}",
293
  )
294
 
295
 
 
543
  with gr.TabItem("🏆 Leaderboard"):
544
  gr.Markdown("### Current Rankings")
545
  leaderboard_html = gr.HTML(value=format_leaderboard_html(load_leaderboard()))
546
+ refresh_btn = gr.Button("Refresh Leaderboard")
547
  refresh_btn.click(refresh_leaderboard, outputs=leaderboard_html)
548
 
549
  # Interactive Demo Tab
 
566
  )
567
 
568
  with gr.Row():
569
+ play_btn = gr.Button("Model Move", variant="primary")
570
+ reset_btn = gr.Button("Reset")
571
 
572
  status_text = gr.Textbox(label="Status", interactive=False)
573
 
 
618
  label="Number of Positions",
619
  )
620
 
621
+ legal_btn = gr.Button("Run Legal Move Evaluation", variant="primary")
622
  legal_results = gr.Markdown()
623
 
624
  legal_btn.click(
 
675
  def verify_webhook_secret(secret: str) -> bool:
676
  """Verify the webhook secret from Hugging Face."""
677
  if not WEBHOOK_SECRET:
678
+ print("WEBHOOK_SECRET not set - skipping verification")
679
  return True
680
  return hmac.compare_digest(WEBHOOK_SECRET, secret)
681
 
 
714
  })
715
 
716
  save_leaderboard(leaderboard)
717
+ print(f"Auto-evaluation complete for {model_id}: legal_rate={results.get('legal_rate_with_retry', 0):.1%}")
718
 
719
  except Exception as e:
720
+ print(f"Auto-evaluation failed for {model_id}: {e}")
721
  import traceback
722
  traceback.print_exc()
723
 
 
729
 
730
  # Verify secret
731
  if not verify_webhook_secret(secret):
732
+ print("Webhook secret verification failed")
733
  return {"error": "Invalid secret"}, 403
734
 
735
  data = await request.json()
 
746
  if event_type in ["create", "update"]:
747
  # Check if it's a chess model
748
  if "chess" in repo_name.lower():
749
+ print(f"Queuing evaluation for chess model: {repo_name}")
750
  background_tasks.add_task(run_auto_evaluation, repo_name)
751
  return {"status": "evaluation_queued", "model": repo_name}
752
  else:
src/evaluate.py CHANGED
@@ -477,9 +477,11 @@ def load_model_from_hub(model_id: str, device: str = "auto"):
477
  Returns:
478
  Tuple of (model, tokenizer).
479
  """
 
 
480
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
481
 
482
- # Import custom classes - this also triggers registration at module load
483
  try:
484
  from src.model import ChessConfig, ChessForCausalLM
485
  from src.tokenizer import ChessTokenizer
@@ -487,31 +489,40 @@ def load_model_from_hub(model_id: str, device: str = "auto"):
487
  from .model import ChessConfig, ChessForCausalLM
488
  from .tokenizer import ChessTokenizer
489
 
490
- # Explicitly register to ensure it's done before loading
491
  try:
492
  AutoConfig.register("chess_transformer", ChessConfig)
493
  except ValueError:
494
- pass # Already registered
495
-
496
  try:
497
  AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
498
  except ValueError:
499
- pass # Already registered
500
 
501
- # Load using our local classes directly (most reliable)
502
  print(f"Loading model {model_id}...")
503
- config = ChessConfig.from_pretrained(model_id, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
504
  model = ChessForCausalLM.from_pretrained(
505
  model_id,
506
  config=config,
507
  device_map=device,
508
- trust_remote_code=True,
509
  )
510
 
511
- # Load tokenizer - try custom class first, then generic
512
  try:
513
  tokenizer = ChessTokenizer.from_pretrained(model_id)
514
- except Exception:
 
515
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
516
 
517
  return model, tokenizer
 
477
  Returns:
478
  Tuple of (model, tokenizer).
479
  """
480
+ import json
481
+ from huggingface_hub import hf_hub_download
482
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
483
 
484
+ # Import custom classes
485
  try:
486
  from src.model import ChessConfig, ChessForCausalLM
487
  from src.tokenizer import ChessTokenizer
 
489
  from .model import ChessConfig, ChessForCausalLM
490
  from .tokenizer import ChessTokenizer
491
 
492
+ # Register BEFORE any from_pretrained calls
493
  try:
494
  AutoConfig.register("chess_transformer", ChessConfig)
495
  except ValueError:
496
+ pass
 
497
  try:
498
  AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
499
  except ValueError:
500
+ pass
501
 
 
502
  print(f"Loading model {model_id}...")
503
+
504
+ # Download and load config manually to avoid transformers auto-detection issues
505
+ config_path = hf_hub_download(repo_id=model_id, filename="config.json")
506
+ with open(config_path, "r") as f:
507
+ config_dict = json.load(f)
508
+
509
+ # Remove model_type to avoid conflicts, instantiate our config directly
510
+ config_dict.pop("model_type", None)
511
+ config_dict.pop("architectures", None)
512
+ config = ChessConfig(**config_dict)
513
+
514
+ # Load model weights with our config
515
  model = ChessForCausalLM.from_pretrained(
516
  model_id,
517
  config=config,
518
  device_map=device,
 
519
  )
520
 
521
+ # Load tokenizer
522
  try:
523
  tokenizer = ChessTokenizer.from_pretrained(model_id)
524
+ except Exception as e:
525
+ print(f"ChessTokenizer failed ({e}), trying AutoTokenizer...")
526
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
527
 
528
  return model, tokenizer