Spaces:
Sleeping
Sleeping
yoel
Refactor: mejora la interfaz de evaluación agregando campos para nombre y matrícula, y actualiza la gestión del leaderboard
a2dd494
| import torch | |
| from datetime import datetime, timezone | |
| from safetensors.torch import load_model | |
| from models import FromZero | |
| from utils import ( | |
| multiclass_accuracy, | |
| calcular_puntaje, | |
| cargar_leaderboard, | |
| guardar_registro_leaderboard, | |
| obtener_sha256, | |
| ) | |
def cargar_evaluar_modelo(archivo, num_clases, test_dataloader):
    """Load a ``FromZero`` model from a safetensors file and measure its test accuracy.

    Args:
        archivo: Path to the ``.safetensor``/``.safetensors`` weights file.
        num_clases: Number of output classes passed to ``FromZero``.
        test_dataloader: Iterable of ``(imagenes, etiquetas)`` batches.

    Returns:
        The accuracy over every test sample as a ``float`` (presumably a
        fraction; ``evaluate_interface`` displays it as ``accuracy * 100`` %),
        or a string starting with ``"Error: "`` if anything fails. Callers
        tell the two cases apart with ``isinstance(result, str)``.
    """
    try:
        modelo = FromZero(num_clases)
        load_model(modelo, archivo)
        modelo.eval()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        modelo.to(device)

        # Sum of (batch accuracy * batch size); may stay a tensor on `device`
        # until the end so we do not force a host sync on every batch.
        suma_ponderada = 0.0
        total_muestras = 0
        with torch.no_grad():
            for imagenes, etiquetas in test_dataloader:
                imagenes = imagenes.to(device)
                etiquetas = etiquetas.to(device)
                predictions = modelo(imagenes)
                # Weight by batch size: a plain mean of per-batch accuracies
                # over-weights a smaller final batch. Assumes
                # multiclass_accuracy returns the batch-mean accuracy — TODO confirm.
                tamano_lote = len(etiquetas)
                suma_ponderada += multiclass_accuracy(predictions, etiquetas) * tamano_lote
                total_muestras += tamano_lote

        if total_muestras == 0:
            raise ValueError("el conjunto de prueba no contiene muestras")
        return float(suma_ponderada) / total_muestras
    except Exception as e:
        # UI boundary: report any failure as text instead of raising.
        return f"Error: {str(e)}"
| def _formatear_leaderboard(registros): | |
| if not registros: | |
| return [] | |
| ordenados = sorted( | |
| registros, | |
| key=lambda r: (r["puntaje"], r["accuracy"]), | |
| reverse=True, | |
| ) | |
| tabla = [] | |
| for entry in ordenados: | |
| sha_marcado = entry["sha256"] + (" *" if entry.get("duplicado") else "") | |
| duplicado = "Sí" if entry.get("duplicado") else "No" | |
| tabla.append( | |
| [ | |
| entry["nombre"], | |
| entry["matricula"], | |
| f"{entry['accuracy_pct']:.2f}%", | |
| entry["puntaje"], | |
| sha_marcado, | |
| duplicado, | |
| entry["timestamp"], | |
| ] | |
| ) | |
| return tabla | |
def evaluate_interface(nombre, matricula, model_file, num_clases, test_dataloader):
    """Validate a submission, evaluate the uploaded model and update the leaderboard.

    Args:
        nombre: Student name typed in the UI (``None`` is treated as empty).
        matricula: Student ID typed in the UI (``None`` is treated as empty).
        model_file: The uploaded weights, either an object whose ``.name`` is
            the file path or the path itself as a ``str``.
        num_clases: Number of classes forwarded to ``cargar_evaluar_modelo``.
        test_dataloader: Test batches forwarded to ``cargar_evaluar_modelo``.

    Returns:
        A 4-tuple ``(mensaje, sha256_texto, puntaje_texto, tabla)``. On success
        it holds the accuracy message, the SHA256 line (suffixed with ``" *"``
        if the hash is flagged as duplicate), the assigned score and the
        updated leaderboard rows. On any validation or evaluation failure the
        middle two items are ``""`` and ``tabla`` is the current leaderboard.
    """

    def _respuesta_error(mensaje):
        # Read the leaderboard only when it is actually shown; on success the
        # freshly saved records are used instead.
        return (mensaje, "", "", _formatear_leaderboard(cargar_leaderboard()))

    nombre = (nombre or "").strip()
    matricula = (matricula or "").strip()
    if not nombre or not matricula:
        return _respuesta_error("Por favor, ingresa nombre y matrícula.")
    if model_file is None:
        return _respuesta_error("Por favor, carga un archivo .safetensor")

    # The upload may arrive as a file-like object exposing its path in `.name`
    # (what this code originally assumed) or as a plain path string
    # (e.g. a Gradio File with type="filepath"). NOTE(review): confirm caller.
    ruta = model_file if isinstance(model_file, str) else model_file.name
    if not ruta.lower().endswith((".safetensor", ".safetensors")):
        return _respuesta_error(
            "Por favor, carga un archivo con extensión .safetensor o .safetensors"
        )

    sha256 = obtener_sha256(ruta)
    accuracy = cargar_evaluar_modelo(ruta, num_clases, test_dataloader)
    # cargar_evaluar_modelo signals failure by returning an "Error: ..." string.
    if isinstance(accuracy, str):
        return _respuesta_error(accuracy)

    puntaje = calcular_puntaje(accuracy)
    accuracy_pct = accuracy * 100
    registro = {
        "nombre": nombre,
        "matricula": matricula,
        "accuracy": accuracy,
        "accuracy_pct": accuracy_pct,
        "puntaje": puntaje,
        "sha256": sha256,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    registros = guardar_registro_leaderboard(registro)
    tabla_final = _formatear_leaderboard(registros)
    # Mark the hash the same way the leaderboard table does when any stored
    # record with this hash is flagged as a duplicate.
    es_duplicado = any(r["sha256"] == sha256 and r.get("duplicado") for r in registros)
    sha_marcado = sha256 + (" *" if es_duplicado else "")
    return (
        f"Precisión del modelo: {accuracy_pct:.2f}%",
        f"SHA256: {sha_marcado}",
        f"Puntaje asignado: {puntaje} pts",
        tabla_final,
    )