evaluador / evaluation.py
yoel
Refactor: mejora la interfaz de evaluación agregando campos para nombre y matrícula, y actualiza la gestión del leaderboard
a2dd494
import torch
from datetime import datetime, timezone
from safetensors.torch import load_model
from models import FromZero
from utils import (
multiclass_accuracy,
calcular_puntaje,
cargar_leaderboard,
guardar_registro_leaderboard,
obtener_sha256,
)
def cargar_evaluar_modelo(archivo, num_clases, test_dataloader):
try:
modelo = FromZero(num_clases)
load_model(modelo, archivo)
modelo.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
modelo.to(device)
accuracy = 0
with torch.no_grad():
for imagenes, etiquetas in test_dataloader:
imagenes = imagenes.to(device)
etiquetas = etiquetas.to(device)
predictions = modelo(imagenes)
accuracy += multiclass_accuracy(predictions, etiquetas)
accuracy = accuracy / len(test_dataloader)
return accuracy
except Exception as e:
return f"Error: {str(e)}"
def _formatear_leaderboard(registros):
if not registros:
return []
ordenados = sorted(
registros,
key=lambda r: (r["puntaje"], r["accuracy"]),
reverse=True,
)
tabla = []
for entry in ordenados:
sha_marcado = entry["sha256"] + (" *" if entry.get("duplicado") else "")
duplicado = "Sí" if entry.get("duplicado") else "No"
tabla.append(
[
entry["nombre"],
entry["matricula"],
f"{entry['accuracy_pct']:.2f}%",
entry["puntaje"],
sha_marcado,
duplicado,
entry["timestamp"],
]
)
return tabla
def evaluate_interface(nombre, matricula, model_file, num_clases, test_dataloader):
nombre = (nombre or "").strip()
matricula = (matricula or "").strip()
tabla_lideres = _formatear_leaderboard(cargar_leaderboard())
if not nombre or not matricula:
return (
"Por favor, ingresa nombre y matrícula.",
"",
"",
tabla_lideres,
)
if model_file is None:
return ("Por favor, carga un archivo .safetensor", "", "", tabla_lideres)
if not model_file.name.endswith(".safetensor") and not model_file.name.endswith(
".safetensors"
):
return (
"Por favor, carga un archivo con extensión .safetensor o .safetensors",
"",
"",
tabla_lideres,
)
sha256 = obtener_sha256(model_file.name)
accuracy = cargar_evaluar_modelo(model_file.name, num_clases, test_dataloader)
if isinstance(accuracy, str):
return (accuracy, "", "", tabla_lideres)
puntaje = calcular_puntaje(accuracy)
accuracy_pct = accuracy * 100
registro = {
"nombre": nombre,
"matricula": matricula,
"accuracy": accuracy,
"accuracy_pct": accuracy_pct,
"puntaje": puntaje,
"sha256": sha256,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
registros = guardar_registro_leaderboard(registro)
tabla_final = _formatear_leaderboard(registros)
sha_marcado = sha256 + (
" *" if any(r["sha256"] == sha256 and r.get("duplicado") for r in registros) else ""
)
return (
f"Precisión del modelo: {accuracy_pct:.2f}%",
f"SHA256: {sha_marcado}",
f"Puntaje asignado: {puntaje} pts",
tabla_final,
)