import torch import torch.nn.functional as F from datetime import datetime, timezone from zoneinfo import ZoneInfo from safetensors.torch import load_model from models import FromZero, UNetSR from utils import ( MODEL_TYPE_CLASIFICACION, MODEL_TYPE_SR, calcular_psnr, multiclass_accuracy, calcular_puntaje, cargar_leaderboard, filtrar_leaderboard_por_tipo, guardar_registro_leaderboard, normalizar_nombre, normalizar_tipo_modelo, obtener_sha256, validar_datos_estudiante, ) TIMEZONE_RD = ZoneInfo("America/Santo_Domingo") ORDEN_LEADERBOARD_POR_DEFECTO = "mejores" ORDENES_LEADERBOARD_VALIDOS = { ORDEN_LEADERBOARD_POR_DEFECTO, "peores", "recientes", "antiguos", } def _formatear_timestamp_rd(timestamp): timestamp_dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) if timestamp_dt.tzinfo is None: timestamp_dt = timestamp_dt.replace(tzinfo=timezone.utc) return timestamp_dt.astimezone(TIMEZONE_RD).strftime("%d/%m/%Y %I:%M:%S %p") def _parsear_timestamp(timestamp): timestamp_dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) if timestamp_dt.tzinfo is None: timestamp_dt = timestamp_dt.replace(tzinfo=timezone.utc) return timestamp_dt def normalizar_orden_leaderboard(orden): orden_normalizado = (orden or ORDEN_LEADERBOARD_POR_DEFECTO).strip().lower() if orden_normalizado in ORDENES_LEADERBOARD_VALIDOS: return orden_normalizado return ORDEN_LEADERBOARD_POR_DEFECTO def cargar_evaluar_modelo_clasificacion(archivo, num_clases, test_dataloader): try: modelo = FromZero(num_clases) load_model(modelo, archivo) modelo.eval() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") modelo.to(device) accuracy = 0 with torch.no_grad(): for imagenes, etiquetas in test_dataloader: imagenes = imagenes.to(device) etiquetas = etiquetas.to(device) predictions = modelo(imagenes) accuracy += multiclass_accuracy(predictions, etiquetas) accuracy = accuracy / len(test_dataloader) return accuracy except Exception as e: return f"Error: {str(e)}" def cargar_evaluar_modelo_sr(archivo, test_dataloader): try: modelo = UNetSR() load_model(modelo, archivo) modelo.eval() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") modelo.to(device) total_psnr = 0.0 with torch.no_grad(): for imagenes_baja_res, imagenes_alta_res in test_dataloader: imagenes_baja_res = imagenes_baja_res.to(device) imagenes_alta_res = imagenes_alta_res.to(device) predictions = modelo(imagenes_baja_res) if predictions.shape[-2:] != imagenes_alta_res.shape[-2:]: predictions = F.interpolate( predictions, size=imagenes_alta_res.shape[-2:], mode="bilinear", align_corners=False, ) total_psnr += calcular_psnr(predictions, imagenes_alta_res) return total_psnr / len(test_dataloader) except Exception as e: return f"Error: {str(e)}" def _ordenar_registros(registros, model_type, orden=ORDEN_LEADERBOARD_POR_DEFECTO): orden = normalizar_orden_leaderboard(orden) if orden == "recientes": return sorted( registros, key=lambda r: ( _parsear_timestamp(r["timestamp"]), r["puntaje"], r.get("psnr", r.get("accuracy", 0.0)), ), reverse=True, ) if orden == "antiguos": return sorted( registros, key=lambda r: ( _parsear_timestamp(r["timestamp"]), r["puntaje"], r.get("psnr", r.get("accuracy", 0.0)), ), ) if orden == "peores": if model_type == MODEL_TYPE_SR: return sorted( registros, key=lambda r: (r["puntaje"], r.get("psnr", 0.0)), ) return sorted( registros, key=lambda r: (r["puntaje"], r.get("accuracy", 0.0)), ) if model_type == MODEL_TYPE_SR: return sorted( registros, key=lambda r: (r["puntaje"], r.get("psnr", 0.0)), reverse=True, ) return sorted( registros, key=lambda r: (r["puntaje"], r.get("accuracy", 0.0)), reverse=True, ) def _formatear_leaderboard( registros, model_type, orden=ORDEN_LEADERBOARD_POR_DEFECTO ): if not registros: return [] ordenados = _ordenar_registros(registros, model_type, orden) tabla = [] for entry in ordenados: sha_marcado = entry["sha256"] + (" *" if entry.get("duplicado") else "") duplicado = "Sí" if entry.get("duplicado") else "No" metric = ( f"{entry.get('psnr', 0.0):.2f} dB" if model_type == MODEL_TYPE_SR else f"{entry['accuracy_pct']:.2f}%" ) tabla.append( [ entry["nombre"], entry["matricula"], metric, entry["puntaje"], sha_marcado, duplicado, _formatear_timestamp_rd(entry["timestamp"]), ] ) return tabla def obtener_tablas_leaderboard( orden_clasificacion=ORDEN_LEADERBOARD_POR_DEFECTO, orden_sr=ORDEN_LEADERBOARD_POR_DEFECTO, ): registros = cargar_leaderboard() return ( _formatear_leaderboard( filtrar_leaderboard_por_tipo(registros, MODEL_TYPE_CLASIFICACION), MODEL_TYPE_CLASIFICACION, orden_clasificacion, ), _formatear_leaderboard( filtrar_leaderboard_por_tipo(registros, MODEL_TYPE_SR), MODEL_TYPE_SR, orden_sr, ), ) def evaluate_interface( nombre, matricula, model_file, model_type, orden_clasificacion, orden_sr, num_clases, test_dataloader, sr_dataloader, ): nombre = normalizar_nombre(nombre) matricula = (matricula or "").strip() model_type = normalizar_tipo_modelo(model_type) tabla_clasificacion, tabla_sr = obtener_tablas_leaderboard( orden_clasificacion, orden_sr ) error_validacion = validar_datos_estudiante(nombre, matricula) if error_validacion: return ( error_validacion, "", "", tabla_clasificacion, tabla_sr, ) if model_file is None: return ( "Por favor, carga un archivo .safetensor", "", "", tabla_clasificacion, tabla_sr, ) if not model_file.name.endswith(".safetensor") and not model_file.name.endswith( ".safetensors" ): return ( "Por favor, carga un archivo con extensión .safetensor o .safetensors", "", "", tabla_clasificacion, tabla_sr, ) sha256 = obtener_sha256(model_file.name) if model_type == MODEL_TYPE_SR: metric_value = cargar_evaluar_modelo_sr(model_file.name, sr_dataloader) metric_label = "PSNR promedio" score_label = "Puntaje SR" else: metric_value = cargar_evaluar_modelo_clasificacion( model_file.name, num_clases, test_dataloader ) metric_label = "Precisión del modelo" score_label = "Puntaje asignado" if isinstance(metric_value, str): return (metric_value, "", "", tabla_clasificacion, tabla_sr) registro = { "nombre": nombre, "matricula": matricula, "model_type": model_type, "sha256": sha256, "timestamp": datetime.now(timezone.utc).isoformat(), } if model_type == MODEL_TYPE_SR: registro["psnr"] = metric_value else: registro["accuracy"] = metric_value registro["accuracy_pct"] = metric_value * 100 registro["puntaje"] = calcular_puntaje(metric_value, model_type=model_type) registros = guardar_registro_leaderboard(registro) tabla_final_clasificacion = _formatear_leaderboard( filtrar_leaderboard_por_tipo(registros, MODEL_TYPE_CLASIFICACION), MODEL_TYPE_CLASIFICACION, orden_clasificacion, ) tabla_final_sr = _formatear_leaderboard( filtrar_leaderboard_por_tipo(registros, MODEL_TYPE_SR), MODEL_TYPE_SR, orden_sr, ) sha_marcado = sha256 + ( " *" if any( r["sha256"] == sha256 and r.get("duplicado") and normalizar_tipo_modelo(r.get("model_type")) == model_type for r in registros ) else "" ) metric_text = ( f"{metric_label}: {metric_value:.2f} dB" if model_type == MODEL_TYPE_SR else f"{metric_label}: {metric_value * 100:.2f}%" ) return ( metric_text, f"SHA256: {sha_marcado}", f"{score_label}: {registro['puntaje']} pts", tabla_final_clasificacion, tabla_final_sr, )