Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| from datetime import datetime, timezone | |
| from zoneinfo import ZoneInfo | |
| from safetensors.torch import load_model | |
| from models import FromZero, UNetSR | |
| from utils import ( | |
| MODEL_TYPE_CLASIFICACION, | |
| MODEL_TYPE_SR, | |
| calcular_psnr, | |
| multiclass_accuracy, | |
| calcular_puntaje, | |
| cargar_leaderboard, | |
| filtrar_leaderboard_por_tipo, | |
| guardar_registro_leaderboard, | |
| normalizar_nombre, | |
| normalizar_tipo_modelo, | |
| obtener_sha256, | |
| validar_datos_estudiante, | |
| ) | |
# Timezone used when rendering leaderboard timestamps for display.
TIMEZONE_RD = ZoneInfo("America/Santo_Domingo")

# Sort order applied when none (or an unknown one) is requested.
ORDEN_LEADERBOARD_POR_DEFECTO = "mejores"

# Every sort order the leaderboard understands.
ORDENES_LEADERBOARD_VALIDOS = set(
    (ORDEN_LEADERBOARD_POR_DEFECTO, "peores", "recientes", "antiguos")
)
| def _formatear_timestamp_rd(timestamp): | |
| timestamp_dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) | |
| if timestamp_dt.tzinfo is None: | |
| timestamp_dt = timestamp_dt.replace(tzinfo=timezone.utc) | |
| return timestamp_dt.astimezone(TIMEZONE_RD).strftime("%d/%m/%Y %I:%M:%S %p") | |
| def _parsear_timestamp(timestamp): | |
| timestamp_dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) | |
| if timestamp_dt.tzinfo is None: | |
| timestamp_dt = timestamp_dt.replace(tzinfo=timezone.utc) | |
| return timestamp_dt | |
def normalizar_orden_leaderboard(orden):
    """Return a canonical leaderboard sort order.

    The requested order is trimmed and lower-cased. Empty or ``None`` input,
    as well as any value outside ``ORDENES_LEADERBOARD_VALIDOS``, falls back
    to ``ORDEN_LEADERBOARD_POR_DEFECTO``.
    """
    candidato = orden if orden else ORDEN_LEADERBOARD_POR_DEFECTO
    candidato = candidato.strip().lower()
    if candidato not in ORDENES_LEADERBOARD_VALIDOS:
        return ORDEN_LEADERBOARD_POR_DEFECTO
    return candidato
def cargar_evaluar_modelo_clasificacion(archivo, num_clases, test_dataloader):
    """Load a ``FromZero`` classifier from a safetensors file and evaluate it.

    Args:
        archivo: Path to the uploaded safetensors weights.
        num_clases: Number of output classes used to build ``FromZero``.
        test_dataloader: Iterable yielding ``(imagenes, etiquetas)`` batches.

    Returns:
        The mean of the per-batch ``multiclass_accuracy`` values as a
        ``float``, or a string starting with ``"Error: "`` if anything fails.
        Callers detect failure with ``isinstance(result, str)``, so errors are
        returned rather than raised.
    """
    try:
        dispositivo = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        modelo = FromZero(num_clases)
        load_model(modelo, archivo)
        modelo.eval()
        modelo.to(dispositivo)

        suma_accuracy = 0.0
        # Count batches while iterating instead of calling len(): this also
        # works for loaders without __len__ and lets us detect an empty set.
        num_lotes = 0
        with torch.no_grad():
            for imagenes, etiquetas in test_dataloader:
                imagenes = imagenes.to(dispositivo)
                etiquetas = etiquetas.to(dispositivo)
                predicciones = modelo(imagenes)
                suma_accuracy += multiclass_accuracy(predicciones, etiquetas)
                num_lotes += 1

        if num_lotes == 0:
            return "Error: el conjunto de prueba está vacío"
        # float() turns a 0-dim tensor into a plain number (JSON/format
        # friendly); it is a no-op if the helper already returns a float.
        return float(suma_accuracy / num_lotes)
    except Exception as e:
        return f"Error: {str(e)}"
def cargar_evaluar_modelo_sr(archivo, test_dataloader):
    """Load a ``UNetSR`` super-resolution model and compute its mean PSNR.

    Args:
        archivo: Path to the uploaded safetensors weights.
        test_dataloader: Iterable yielding ``(imagenes_baja_res,
            imagenes_alta_res)`` batches.

    Returns:
        The mean of the per-batch ``calcular_psnr`` values as a ``float``, or
        a string starting with ``"Error: "`` if anything fails. Callers detect
        failure with ``isinstance(result, str)``.
    """
    try:
        dispositivo = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        modelo = UNetSR()
        load_model(modelo, archivo)
        modelo.eval()
        modelo.to(dispositivo)

        total_psnr = 0.0
        # Count batches while iterating instead of calling len(): this also
        # works for loaders without __len__ and lets us detect an empty set.
        num_lotes = 0
        with torch.no_grad():
            for imagenes_baja_res, imagenes_alta_res in test_dataloader:
                imagenes_baja_res = imagenes_baja_res.to(dispositivo)
                imagenes_alta_res = imagenes_alta_res.to(dispositivo)
                predicciones = modelo(imagenes_baja_res)
                # If the output's spatial size differs from the target's,
                # resize it so PSNR compares same-sized images.
                if predicciones.shape[-2:] != imagenes_alta_res.shape[-2:]:
                    predicciones = F.interpolate(
                        predicciones,
                        size=imagenes_alta_res.shape[-2:],
                        mode="bilinear",
                        align_corners=False,
                    )
                total_psnr += calcular_psnr(predicciones, imagenes_alta_res)
                num_lotes += 1

        if num_lotes == 0:
            return "Error: el conjunto de prueba está vacío"
        # float() turns a 0-dim tensor into a plain number; no-op for floats.
        return float(total_psnr / num_lotes)
    except Exception as e:
        return f"Error: {str(e)}"
def _ordenar_registros(registros, model_type, orden=ORDEN_LEADERBOARD_POR_DEFECTO):
    """Return a new list of leaderboard records sorted according to ``orden``.

    ``orden`` is first normalised with ``normalizar_orden_leaderboard``. The
    orders behave as follows:

    * ``"recientes"`` / ``"antiguos"``: sort by parsed timestamp, newest /
      oldest first. Ties are broken by ``puntaje`` and then by the record's
      ``psnr`` (or ``accuracy``).
    * ``"mejores"`` (default) / ``"peores"``: sort by ``puntaje`` and then by
      the model metric (``psnr`` for ``MODEL_TYPE_SR``, ``accuracy``
      otherwise), descending / ascending.
    """
    criterio = normalizar_orden_leaderboard(orden)

    if criterio in ("recientes", "antiguos"):
        def clave_temporal(registro):
            return (
                _parsear_timestamp(registro["timestamp"]),
                registro["puntaje"],
                registro.get("psnr", registro.get("accuracy", 0.0)),
            )

        return sorted(
            registros, key=clave_temporal, reverse=(criterio == "recientes")
        )

    nombre_metrica = "psnr" if model_type == MODEL_TYPE_SR else "accuracy"

    def clave_puntaje(registro):
        return (registro["puntaje"], registro.get(nombre_metrica, 0.0))

    return sorted(registros, key=clave_puntaje, reverse=(criterio != "peores"))
def _formatear_leaderboard(
    registros, model_type, orden=ORDEN_LEADERBOARD_POR_DEFECTO
):
    """Convert leaderboard records into display rows sorted by ``orden``.

    Each row is ``[nombre, matricula, metric, puntaje, sha256, duplicado,
    fecha]``:

    * ``metric`` is ``"<psnr> dB"`` for ``MODEL_TYPE_SR`` and
      ``"<accuracy_pct>%"`` otherwise.
    * The SHA is suffixed with ``" *"`` and ``duplicado`` is ``"Sí"`` when the
      record is flagged as duplicate.
    * ``fecha`` is the timestamp rendered in ``TIMEZONE_RD``.

    Returns an empty list when there are no records.
    """
    if not registros:
        return []

    es_sr = model_type == MODEL_TYPE_SR
    tabla = []
    for entry in _ordenar_registros(registros, model_type, orden):
        es_duplicado = bool(entry.get("duplicado"))

        if es_sr:
            metric = f"{entry.get('psnr', 0.0):.2f} dB"
        else:
            # If a record lacks "accuracy_pct", derive it from "accuracy"
            # instead of crashing the whole table with a KeyError. This mirrors
            # the tolerant .get() already used for "psnr".
            accuracy_pct = entry.get("accuracy_pct")
            if accuracy_pct is None:
                accuracy_pct = entry.get("accuracy", 0.0) * 100
            metric = f"{accuracy_pct:.2f}%"

        tabla.append(
            [
                entry["nombre"],
                entry["matricula"],
                metric,
                entry["puntaje"],
                entry["sha256"] + (" *" if es_duplicado else ""),
                "Sí" if es_duplicado else "No",
                _formatear_timestamp_rd(entry["timestamp"]),
            ]
        )
    return tabla
def obtener_tablas_leaderboard(
    orden_clasificacion=ORDEN_LEADERBOARD_POR_DEFECTO,
    orden_sr=ORDEN_LEADERBOARD_POR_DEFECTO,
):
    """Load the stored leaderboard and build both display tables.

    Args:
        orden_clasificacion: Sort order for the classification table.
        orden_sr: Sort order for the super-resolution table.

    Returns:
        ``(tabla_clasificacion, tabla_sr)``. Each table is produced by
        ``_formatear_leaderboard`` from the records of its model type.
    """
    todos_los_registros = cargar_leaderboard()
    tablas = []
    for tipo, orden in (
        (MODEL_TYPE_CLASIFICACION, orden_clasificacion),
        (MODEL_TYPE_SR, orden_sr),
    ):
        del_tipo = filtrar_leaderboard_por_tipo(todos_los_registros, tipo)
        tablas.append(_formatear_leaderboard(del_tipo, tipo, orden))
    return tuple(tablas)
def evaluate_interface(
    nombre,
    matricula,
    model_file,
    model_type,
    orden_clasificacion,
    orden_sr,
    num_clases,
    test_dataloader,
    sr_dataloader,
):
    """Validate a submission, evaluate the model and record it on the leaderboard.

    Args:
        nombre: Student name; normalised with ``normalizar_nombre``.
        matricula: Student ID; stripped of surrounding whitespace.
        model_file: The uploaded weights. This is either an object whose
            ``.name`` holds the file path or the path itself as a ``str``.
        model_type: Model kind, normalised with ``normalizar_tipo_modelo``.
            ``MODEL_TYPE_SR`` selects super-resolution; anything else is
            evaluated as classification.
        orden_clasificacion: Sort order for the classification table.
        orden_sr: Sort order for the super-resolution table.
        num_clases: Number of classes for the classification model.
        test_dataloader: Evaluation data for classification.
        sr_dataloader: Evaluation data for super-resolution.

    Returns:
        A 5-tuple ``(metric_text, sha_text, score_text, tabla_clasificacion,
        tabla_sr)``. On any validation or evaluation failure, the first item is
        the error message and the next two are empty strings. In that case
        the tables show the current leaderboard.
    """
    nombre = normalizar_nombre(nombre)
    matricula = (matricula or "").strip()
    model_type = normalizar_tipo_modelo(model_type)

    def respuesta_error(mensaje):
        # The leaderboard is only read here, on failure paths. The success
        # path rebuilds both tables from the records returned after saving.
        tabla_clasificacion, tabla_sr = obtener_tablas_leaderboard(
            orden_clasificacion, orden_sr
        )
        return (mensaje, "", "", tabla_clasificacion, tabla_sr)

    error_validacion = validar_datos_estudiante(nombre, matricula)
    if error_validacion:
        return respuesta_error(error_validacion)
    if model_file is None:
        return respuesta_error("Por favor, carga un archivo .safetensor")

    # Accept a plain path string as well as an object exposing the path in
    # ``.name``; previously a string upload raised AttributeError.
    ruta_modelo = model_file if isinstance(model_file, str) else model_file.name
    # Case-insensitive extension check, so "MODEL.SAFETENSORS" is accepted.
    if not ruta_modelo.lower().endswith((".safetensor", ".safetensors")):
        return respuesta_error(
            "Por favor, carga un archivo con extensión .safetensor o .safetensors"
        )

    sha256 = obtener_sha256(ruta_modelo)
    if model_type == MODEL_TYPE_SR:
        metric_value = cargar_evaluar_modelo_sr(ruta_modelo, sr_dataloader)
        metric_label = "PSNR promedio"
        score_label = "Puntaje SR"
    else:
        metric_value = cargar_evaluar_modelo_clasificacion(
            ruta_modelo, num_clases, test_dataloader
        )
        metric_label = "Precisión del modelo"
        score_label = "Puntaje asignado"
    # The evaluators report failures as "Error: ..." strings instead of raising.
    if isinstance(metric_value, str):
        return respuesta_error(metric_value)

    registro = {
        "nombre": nombre,
        "matricula": matricula,
        "model_type": model_type,
        "sha256": sha256,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    if model_type == MODEL_TYPE_SR:
        registro["psnr"] = metric_value
    else:
        registro["accuracy"] = metric_value
        registro["accuracy_pct"] = metric_value * 100
    registro["puntaje"] = calcular_puntaje(metric_value, model_type=model_type)

    registros = guardar_registro_leaderboard(registro)
    tabla_final_clasificacion = _formatear_leaderboard(
        filtrar_leaderboard_por_tipo(registros, MODEL_TYPE_CLASIFICACION),
        MODEL_TYPE_CLASIFICACION,
        orden_clasificacion,
    )
    tabla_final_sr = _formatear_leaderboard(
        filtrar_leaderboard_por_tipo(registros, MODEL_TYPE_SR),
        MODEL_TYPE_SR,
        orden_sr,
    )

    # Mark the hash with " *" when any saved record of the same model type
    # with this SHA is flagged as a duplicate.
    es_duplicado = any(
        r["sha256"] == sha256
        and r.get("duplicado")
        and normalizar_tipo_modelo(r.get("model_type")) == model_type
        for r in registros
    )
    sha_marcado = sha256 + (" *" if es_duplicado else "")

    metric_text = (
        f"{metric_label}: {metric_value:.2f} dB"
        if model_type == MODEL_TYPE_SR
        else f"{metric_label}: {metric_value * 100:.2f}%"
    )
    return (
        metric_text,
        f"SHA256: {sha_marcado}",
        f"{score_label}: {registro['puntaje']} pts",
        tabla_final_clasificacion,
        tabla_final_sr,
    )