File size: 3,487 Bytes
237774d
a2dd494
237774d
d57909e
a2dd494
 
 
 
 
 
 
237774d
 
d57909e
237774d
d57909e
237774d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2dd494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237774d
a2dd494
237774d
302b2b5
2a9e07e
 
a2dd494
 
 
 
 
 
237774d
a2dd494
d57909e
237774d
a2dd494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
from datetime import datetime, timezone
from safetensors.torch import load_model
from models import FromZero
from utils import (
    multiclass_accuracy,
    calcular_puntaje,
    cargar_leaderboard,
    guardar_registro_leaderboard,
    obtener_sha256,
)


def cargar_evaluar_modelo(archivo, num_clases, test_dataloader):
    try:
        modelo = FromZero(num_clases)

        load_model(modelo, archivo)
        modelo.eval()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        modelo.to(device)
        accuracy = 0

        with torch.no_grad():
            for imagenes, etiquetas in test_dataloader:
                imagenes = imagenes.to(device)
                etiquetas = etiquetas.to(device)
                predictions = modelo(imagenes)
                accuracy += multiclass_accuracy(predictions, etiquetas)

        accuracy = accuracy / len(test_dataloader)
        return accuracy
    except Exception as e:
        return f"Error: {str(e)}"


def _formatear_leaderboard(registros):
    if not registros:
        return []
    ordenados = sorted(
        registros,
        key=lambda r: (r["puntaje"], r["accuracy"]),
        reverse=True,
    )
    tabla = []
    for entry in ordenados:
        sha_marcado = entry["sha256"] + (" *" if entry.get("duplicado") else "")
        duplicado = "Sí" if entry.get("duplicado") else "No"
        tabla.append(
            [
                entry["nombre"],
                entry["matricula"],
                f"{entry['accuracy_pct']:.2f}%",
                entry["puntaje"],
                sha_marcado,
                duplicado,
                entry["timestamp"],
            ]
        )
    return tabla


def evaluate_interface(nombre, matricula, model_file, num_clases, test_dataloader):
    nombre = (nombre or "").strip()
    matricula = (matricula or "").strip()
    tabla_lideres = _formatear_leaderboard(cargar_leaderboard())

    if not nombre or not matricula:
        return (
            "Por favor, ingresa nombre y matrícula.",
            "",
            "",
            tabla_lideres,
        )

    if model_file is None:
        return ("Por favor, carga un archivo .safetensor", "", "", tabla_lideres)

    if not model_file.name.endswith(".safetensor") and not model_file.name.endswith(
        ".safetensors"
    ):
        return (
            "Por favor, carga un archivo con extensión .safetensor o .safetensors",
            "",
            "",
            tabla_lideres,
        )

    sha256 = obtener_sha256(model_file.name)
    accuracy = cargar_evaluar_modelo(model_file.name, num_clases, test_dataloader)

    if isinstance(accuracy, str):
        return (accuracy, "", "", tabla_lideres)

    puntaje = calcular_puntaje(accuracy)
    accuracy_pct = accuracy * 100
    registro = {
        "nombre": nombre,
        "matricula": matricula,
        "accuracy": accuracy,
        "accuracy_pct": accuracy_pct,
        "puntaje": puntaje,
        "sha256": sha256,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    registros = guardar_registro_leaderboard(registro)
    tabla_final = _formatear_leaderboard(registros)
    sha_marcado = sha256 + (
        " *" if any(r["sha256"] == sha256 and r.get("duplicado") for r in registros) else ""
    )

    return (
        f"Precisión del modelo: {accuracy_pct:.2f}%",
        f"SHA256: {sha_marcado}",
        f"Puntaje asignado: {puntaje} pts",
        tabla_final,
    )