evaluador / evaluation.py
yoel
feat: orden configurable para leaderboards
e0f8389
import torch
import torch.nn.functional as F
from datetime import datetime, timezone
from zoneinfo import ZoneInfo
from safetensors.torch import load_model
from models import FromZero, UNetSR
from utils import (
MODEL_TYPE_CLASIFICACION,
MODEL_TYPE_SR,
calcular_psnr,
multiclass_accuracy,
calcular_puntaje,
cargar_leaderboard,
filtrar_leaderboard_por_tipo,
guardar_registro_leaderboard,
normalizar_nombre,
normalizar_tipo_modelo,
obtener_sha256,
validar_datos_estudiante,
)
TIMEZONE_RD = ZoneInfo("America/Santo_Domingo")
ORDEN_LEADERBOARD_POR_DEFECTO = "mejores"
ORDENES_LEADERBOARD_VALIDOS = {
ORDEN_LEADERBOARD_POR_DEFECTO,
"peores",
"recientes",
"antiguos",
}
def _formatear_timestamp_rd(timestamp):
timestamp_dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
if timestamp_dt.tzinfo is None:
timestamp_dt = timestamp_dt.replace(tzinfo=timezone.utc)
return timestamp_dt.astimezone(TIMEZONE_RD).strftime("%d/%m/%Y %I:%M:%S %p")
def _parsear_timestamp(timestamp):
timestamp_dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
if timestamp_dt.tzinfo is None:
timestamp_dt = timestamp_dt.replace(tzinfo=timezone.utc)
return timestamp_dt
def normalizar_orden_leaderboard(orden):
orden_normalizado = (orden or ORDEN_LEADERBOARD_POR_DEFECTO).strip().lower()
if orden_normalizado in ORDENES_LEADERBOARD_VALIDOS:
return orden_normalizado
return ORDEN_LEADERBOARD_POR_DEFECTO
def cargar_evaluar_modelo_clasificacion(archivo, num_clases, test_dataloader):
try:
modelo = FromZero(num_clases)
load_model(modelo, archivo)
modelo.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
modelo.to(device)
accuracy = 0
with torch.no_grad():
for imagenes, etiquetas in test_dataloader:
imagenes = imagenes.to(device)
etiquetas = etiquetas.to(device)
predictions = modelo(imagenes)
accuracy += multiclass_accuracy(predictions, etiquetas)
accuracy = accuracy / len(test_dataloader)
return accuracy
except Exception as e:
return f"Error: {str(e)}"
def cargar_evaluar_modelo_sr(archivo, test_dataloader):
try:
modelo = UNetSR()
load_model(modelo, archivo)
modelo.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
modelo.to(device)
total_psnr = 0.0
with torch.no_grad():
for imagenes_baja_res, imagenes_alta_res in test_dataloader:
imagenes_baja_res = imagenes_baja_res.to(device)
imagenes_alta_res = imagenes_alta_res.to(device)
predictions = modelo(imagenes_baja_res)
if predictions.shape[-2:] != imagenes_alta_res.shape[-2:]:
predictions = F.interpolate(
predictions,
size=imagenes_alta_res.shape[-2:],
mode="bilinear",
align_corners=False,
)
total_psnr += calcular_psnr(predictions, imagenes_alta_res)
return total_psnr / len(test_dataloader)
except Exception as e:
return f"Error: {str(e)}"
def _ordenar_registros(registros, model_type, orden=ORDEN_LEADERBOARD_POR_DEFECTO):
orden = normalizar_orden_leaderboard(orden)
if orden == "recientes":
return sorted(
registros,
key=lambda r: (
_parsear_timestamp(r["timestamp"]),
r["puntaje"],
r.get("psnr", r.get("accuracy", 0.0)),
),
reverse=True,
)
if orden == "antiguos":
return sorted(
registros,
key=lambda r: (
_parsear_timestamp(r["timestamp"]),
r["puntaje"],
r.get("psnr", r.get("accuracy", 0.0)),
),
)
if orden == "peores":
if model_type == MODEL_TYPE_SR:
return sorted(
registros,
key=lambda r: (r["puntaje"], r.get("psnr", 0.0)),
)
return sorted(
registros,
key=lambda r: (r["puntaje"], r.get("accuracy", 0.0)),
)
if model_type == MODEL_TYPE_SR:
return sorted(
registros,
key=lambda r: (r["puntaje"], r.get("psnr", 0.0)),
reverse=True,
)
return sorted(
registros,
key=lambda r: (r["puntaje"], r.get("accuracy", 0.0)),
reverse=True,
)
def _formatear_leaderboard(
registros, model_type, orden=ORDEN_LEADERBOARD_POR_DEFECTO
):
if not registros:
return []
ordenados = _ordenar_registros(registros, model_type, orden)
tabla = []
for entry in ordenados:
sha_marcado = entry["sha256"] + (" *" if entry.get("duplicado") else "")
duplicado = "Sí" if entry.get("duplicado") else "No"
metric = (
f"{entry.get('psnr', 0.0):.2f} dB"
if model_type == MODEL_TYPE_SR
else f"{entry['accuracy_pct']:.2f}%"
)
tabla.append(
[
entry["nombre"],
entry["matricula"],
metric,
entry["puntaje"],
sha_marcado,
duplicado,
_formatear_timestamp_rd(entry["timestamp"]),
]
)
return tabla
def obtener_tablas_leaderboard(
orden_clasificacion=ORDEN_LEADERBOARD_POR_DEFECTO,
orden_sr=ORDEN_LEADERBOARD_POR_DEFECTO,
):
registros = cargar_leaderboard()
return (
_formatear_leaderboard(
filtrar_leaderboard_por_tipo(registros, MODEL_TYPE_CLASIFICACION),
MODEL_TYPE_CLASIFICACION,
orden_clasificacion,
),
_formatear_leaderboard(
filtrar_leaderboard_por_tipo(registros, MODEL_TYPE_SR),
MODEL_TYPE_SR,
orden_sr,
),
)
def evaluate_interface(
nombre,
matricula,
model_file,
model_type,
orden_clasificacion,
orden_sr,
num_clases,
test_dataloader,
sr_dataloader,
):
nombre = normalizar_nombre(nombre)
matricula = (matricula or "").strip()
model_type = normalizar_tipo_modelo(model_type)
tabla_clasificacion, tabla_sr = obtener_tablas_leaderboard(
orden_clasificacion, orden_sr
)
error_validacion = validar_datos_estudiante(nombre, matricula)
if error_validacion:
return (
error_validacion,
"",
"",
tabla_clasificacion,
tabla_sr,
)
if model_file is None:
return (
"Por favor, carga un archivo .safetensor",
"",
"",
tabla_clasificacion,
tabla_sr,
)
if not model_file.name.endswith(".safetensor") and not model_file.name.endswith(
".safetensors"
):
return (
"Por favor, carga un archivo con extensión .safetensor o .safetensors",
"",
"",
tabla_clasificacion,
tabla_sr,
)
sha256 = obtener_sha256(model_file.name)
if model_type == MODEL_TYPE_SR:
metric_value = cargar_evaluar_modelo_sr(model_file.name, sr_dataloader)
metric_label = "PSNR promedio"
score_label = "Puntaje SR"
else:
metric_value = cargar_evaluar_modelo_clasificacion(
model_file.name, num_clases, test_dataloader
)
metric_label = "Precisión del modelo"
score_label = "Puntaje asignado"
if isinstance(metric_value, str):
return (metric_value, "", "", tabla_clasificacion, tabla_sr)
registro = {
"nombre": nombre,
"matricula": matricula,
"model_type": model_type,
"sha256": sha256,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
if model_type == MODEL_TYPE_SR:
registro["psnr"] = metric_value
else:
registro["accuracy"] = metric_value
registro["accuracy_pct"] = metric_value * 100
registro["puntaje"] = calcular_puntaje(metric_value, model_type=model_type)
registros = guardar_registro_leaderboard(registro)
tabla_final_clasificacion = _formatear_leaderboard(
filtrar_leaderboard_por_tipo(registros, MODEL_TYPE_CLASIFICACION),
MODEL_TYPE_CLASIFICACION,
orden_clasificacion,
)
tabla_final_sr = _formatear_leaderboard(
filtrar_leaderboard_por_tipo(registros, MODEL_TYPE_SR),
MODEL_TYPE_SR,
orden_sr,
)
sha_marcado = sha256 + (
" *"
if any(
r["sha256"] == sha256
and r.get("duplicado")
and normalizar_tipo_modelo(r.get("model_type")) == model_type
for r in registros
)
else ""
)
metric_text = (
f"{metric_label}: {metric_value:.2f} dB"
if model_type == MODEL_TYPE_SR
else f"{metric_label}: {metric_value * 100:.2f}%"
)
return (
metric_text,
f"SHA256: {sha_marcado}",
f"{score_label}: {registro['puntaje']} pts",
tabla_final_clasificacion,
tabla_final_sr,
)