evaluador / batch_eval.py
yoel
feat: añadir script batch para evaluar modelos
32ff46c
#!/usr/bin/env python3
"""Batch evaluator: recorre una carpeta (o archivo) y evalúa modelos.
- Por defecto evalúa clasificación
- Si se pasa --sr evalúa super-resolución
- Genera un archivo JSON en la misma ruta con los resultados y pares de alta similitud
"""
import argparse
import json
import os
import sys
import math
from datetime import datetime
import torch
import torch.nn.functional as F
from safetensors.torch import load_model
from utils import (
cargar_etiquetas,
obtener_sha256,
calcular_puntaje,
MODEL_TYPE_CLASIFICACION,
MODEL_TYPE_SR,
)
import dataset
import evaluation
from models import FromZero, UNetSR
# Default value of the --threshold CLI option: pairs of models whose parameter
# vectors reach this similarity are listed in the report.
SIMILARITY_DEFAULT_THRESHOLD = 0.95
def find_model_files(path):
    """Collect the safetensors model files referenced by ``path``.

    Args:
        path: Either a single model file or a directory, which is scanned
            non-recursively.

    Returns:
        A list of file paths. For a file argument this is a one-element list
        holding its absolute path when the extension is ``.safetensor`` or
        ``.safetensors`` (case-insensitive), otherwise an empty list. For a
        directory it is every matching regular file, joined onto ``path`` and
        ordered by file name.

    Raises:
        OSError: If ``path`` is not a file and cannot be listed as a directory.
    """
    extensions = (".safetensor", ".safetensors")

    if os.path.isfile(path):
        if path.lower().endswith(extensions):
            return [os.path.abspath(path)]
        return []

    files = []
    for name in sorted(os.listdir(path)):
        if not name.lower().endswith(extensions):
            continue
        candidate = os.path.join(path, name)
        # A sub-directory (or dangling symlink) can carry a model extension
        # too; only regular files can actually be loaded as models.
        if os.path.isfile(candidate):
            files.append(candidate)
    return files
def build_dataloaders():
    """Prepare the data needed by both evaluation modes.

    Calls ``cargar_etiquetas()`` and forwards the third value it returns to
    ``dataset.cargar_datasets``, which must yield exactly two objects.

    Returns:
        A ``(num_classes, test_dl, sr_dl)`` tuple: the second value returned by
        ``cargar_etiquetas()`` followed by the two objects produced by
        ``dataset.cargar_datasets`` (presumably the classification and
        super-resolution loaders; confirm against ``dataset``).
    """
    # The first value returned by cargar_etiquetas() is not needed here.
    _labels, class_count, dataset_code = cargar_etiquetas()
    classification_loader, sr_loader = dataset.cargar_datasets(dataset_code)
    return class_count, classification_loader, sr_loader
def extract_normalized_vector(model_cls, model_path, num_classes=None):
    """Load a checkpoint into ``model_cls`` and flatten its parameters.

    Args:
        model_cls: Model class to instantiate. It is called with
            ``num_classes`` when that is given, falling back to a no-argument
            call if that raises ``TypeError``; otherwise it is called with no
            arguments.
        model_path: Path of the safetensors file passed to ``load_model``.
        num_classes: Optional constructor argument for ``model_cls``.

    Returns:
        A ``(vector, error, length)`` tuple. On success ``vector`` is a 1-D
        float32 CPU tensor with every non-empty parameter concatenated in
        ``model.parameters()`` order and divided by its L2 norm (left
        unscaled when the norm is not positive), ``error`` is ``None`` and
        ``length`` is the number of elements. On any failure, including a
        failure to construct the model, the tuple is ``(None, message, 0)``.
    """
    # Build the model first. Construction problems are reported through the
    # error slot of the returned tuple, like every other failure, instead of
    # escaping as an exception.
    try:
        if num_classes is None:
            model = model_cls()
        else:
            try:
                model = model_cls(num_classes)
            except TypeError:
                # Some architectures take no constructor arguments
                # (UNetSR() takes none, FromZero requires num_classes).
                model = model_cls()
    except Exception as e:
        return None, f"Error construyendo modelo: {e}", 0

    try:
        load_model(model, model_path)
        model.to("cpu")
        # reshape() rather than view(): view() raises on non-contiguous
        # parameters, reshape() copies only when it has to.
        parts = [
            p.detach().cpu().reshape(-1).float()
            for p in model.parameters()
            if p.numel() > 0
        ]
        if not parts:
            return None, "Modelo sin parámetros", 0
        vec = torch.cat(parts)
        norm = float(vec.norm().item())
        if norm > 0:
            vec = vec / norm
        return vec, None, vec.numel()
    except Exception as e:
        return None, str(e), 0
    finally:
        # Drop the model reference before asking CUDA to release cached
        # blocks; the cache flush is best-effort and must never mask the
        # result being returned.
        del model
        if torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
def pairwise_similarities(file_vectors_by_len, threshold):
    """Report every pair of same-length vectors that is at least ``threshold`` similar.

    Args:
        file_vectors_by_len: Mapping ``length -> [(index, filepath, vector), ...]``.
            Vectors within one group are 1-D tensors of equal length and are
            expected to be L2-normalized already, so that their dot product
            equals their cosine similarity.
        threshold: Minimum similarity (inclusive) for a pair to be reported.

    Returns:
        A list of ``{"file_a", "file_b", "similarity"}`` dicts. Groups are
        visited in mapping order and, inside a group, pairs ``(i, j)`` with
        ``i < j`` are emitted in row-major order.
    """
    pairs = []
    for items in file_vectors_by_len.values():
        if len(items) < 2:
            continue
        paths = [it[1] for it in items]
        vecs = [it[2] for it in items]
        count = len(vecs)

        try:
            # Fast path: one matrix product gives all similarities at once.
            stack = torch.stack(vecs)  # (N, L)
            # A single tolist() yields the same Python floats as calling
            # .item() on every cell, without N^2 tensor round-trips.
            sims = (stack @ stack.t()).tolist()
            group_pairs = []
            for i in range(count):
                row = sims[i]
                for j in range(i + 1, count):
                    if row[j] >= threshold:
                        group_pairs.append({
                            "file_a": paths[i],
                            "file_b": paths[j],
                            "similarity": row[j],
                        })
        except Exception:
            # Fallback (e.g. the stacked matrix does not fit in memory):
            # compare one pair at a time. Start from an empty list so nothing
            # gathered by the failed fast path is reported twice.
            group_pairs = []
            for i in range(count):
                for j in range(i + 1, count):
                    try:
                        sim = float(
                            F.cosine_similarity(
                                vecs[i].unsqueeze(0), vecs[j].unsqueeze(0), dim=1
                            ).item()
                        )
                    except Exception:
                        # Best-effort: a pair that cannot be compared is
                        # skipped rather than aborting the whole report.
                        continue
                    if sim >= threshold:
                        group_pairs.append({
                            "file_a": paths[i],
                            "file_b": paths[j],
                            "similarity": sim,
                        })
        pairs.extend(group_pairs)
    return pairs
def _resolve_output_path(target_path, out_arg):
    """Return the path the JSON report is written to.

    ``out_arg`` is used verbatim when it is absolute or has a directory
    component. A bare file name, or the default ``batch_evaluation.json``
    when ``out_arg`` is empty, is placed in ``target_path`` if that is a
    directory, otherwise in the directory part of ``target_path``.
    """
    if out_arg and (os.path.isabs(out_arg) or os.path.dirname(out_arg)):
        return out_arg
    base = target_path if os.path.isdir(target_path) else os.path.dirname(target_path)
    return os.path.join(base, out_arg or "batch_evaluation.json")


def _mark_exact_duplicates(results):
    """Group result entries by sha256 and flag those sharing a hash.

    Entries whose hash occurs more than once get ``"duplicado": True``.
    Returns a list of ``{"sha256", "files"}`` dicts, one per repeated hash.
    """
    sha_groups = {}
    for r in results:
        sha = r.get("sha256")
        if sha:
            sha_groups.setdefault(sha, []).append(r["path"])

    exact_duplicates = [
        {"sha256": sha, "files": paths}
        for sha, paths in sha_groups.items()
        if len(paths) > 1
    ]
    repeated = {group["sha256"] for group in exact_duplicates}
    for r in results:
        if r.get("sha256") in repeated:
            r["duplicado"] = True
    return exact_duplicates


def main():
    """Command-line entry point: evaluate models and write a JSON report.

    Every model file found under the ``path`` argument is hashed, evaluated
    (classification by default, super-resolution with ``--sr``) and turned
    into a parameter vector. The report lists the per-file results, groups
    of files with identical sha256, and pairs of files whose vectors reach
    ``--threshold``.

    Exit status: 2 when ``path`` does not exist, 0 when no model files are
    found, 3 when the report cannot be serialized or written.
    """
    # Local import: only the report timestamp needs it.
    from datetime import timezone

    parser = argparse.ArgumentParser(description="Batch evaluation de modelos (.safetensor)")
    parser.add_argument("path", help="Archivo .safetensor o carpeta con modelos")
    parser.add_argument("--sr", action="store_true", help="Evaluar como super-resolución (SR)")
    parser.add_argument(
        "--threshold",
        type=float,
        default=SIMILARITY_DEFAULT_THRESHOLD,
        help="Umbral de similitud (0..1) para reportar pares altamente parecidos",
    )
    parser.add_argument(
        "--out",
        default=None,
        help="Archivo de salida (si no se da, se usa batch_evaluation.json en la carpeta) ",
    )
    args = parser.parse_args()

    target_path = args.path
    if not os.path.exists(target_path):
        print(f"Error: ruta no existe: {target_path}")
        sys.exit(2)

    model_files = find_model_files(target_path)
    if not model_files:
        print("No se encontraron archivos .safetensor en la ruta dada.")
        sys.exit(0)

    num_classes, test_dl, sr_dl = build_dataloaders()

    sr_mode = bool(args.sr)
    metric_label = "psnr" if sr_mode else "accuracy"
    model_type = MODEL_TYPE_SR if sr_mode else MODEL_TYPE_CLASIFICACION
    model_cls = UNetSR if sr_mode else FromZero

    results = []
    # Vectors are only comparable when they have the same number of
    # elements, so they are bucketed as: length -> [(idx, path, vector)].
    file_vectors_by_len = {}

    for idx, fpath in enumerate(model_files):
        entry = {
            "path": fpath,
            "sha256": None,
            "metric": None,
            "metric_label": metric_label,
            "score": None,
            "error": None,
            "vector_len": 0,
        }

        try:
            entry["sha256"] = obtener_sha256(fpath)
        except Exception as e:
            # A file that cannot even be hashed is recorded and skipped.
            entry["error"] = f"Error calculando sha256: {e}"
            results.append(entry)
            continue

        # Metric: the evaluation helpers signal a failure by returning a
        # string instead of a number.
        try:
            if sr_mode:
                metric_val = evaluation.cargar_evaluar_modelo_sr(fpath, sr_dl)
            else:
                metric_val = evaluation.cargar_evaluar_modelo_clasificacion(fpath, num_classes, test_dl)
            if isinstance(metric_val, str):
                entry["error"] = metric_val
            else:
                entry["metric"] = float(metric_val)
                entry["score"] = calcular_puntaje(metric_val, model_type=model_type)
        except Exception as e:
            entry["error"] = f"Error evaluando modelo: {e}"

        # Parameter vector, extracted even when the evaluation failed.
        try:
            vec, vec_err, vec_len = extract_normalized_vector(
                model_cls, fpath, num_classes=(None if sr_mode else num_classes)
            )
            if vec_err:
                entry["vector_error"] = vec_err
            else:
                entry["vector_len"] = int(vec_len)
                file_vectors_by_len.setdefault(int(vec_len), []).append((idx, fpath, vec))
        except Exception as e:
            entry["vector_error"] = f"Error extrayendo vector: {e}"

        results.append(entry)

    # Exact duplicates (same sha256).
    exact_duplicates = _mark_exact_duplicates(results)

    # Near duplicates (similar parameter vectors of the same length).
    similarity_pairs = pairwise_similarities(file_vectors_by_len, args.threshold)

    # Build the path -> sha256 index once instead of rescanning the results
    # for both members of every pair; the first entry wins for a given path.
    sha_by_path = {}
    for r in results:
        sha_by_path.setdefault(r["path"], r["sha256"])
    for p in similarity_pairs:
        sha_a = sha_by_path.get(p["file_a"])
        sha_b = sha_by_path.get(p["file_b"])
        if sha_a and sha_b and sha_a == sha_b:
            p["note"] = "Exact duplicate (same sha256)"
        else:
            p["note"] = "High parameter similarity"

    # datetime.utcnow() is deprecated; take an aware UTC time and drop the
    # tzinfo so the serialized form keeps the "<naive iso>Z" layout.
    evaluated_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    report = {
        "evaluated_at": evaluated_at,
        "mode": model_type,
        "similarity_threshold": args.threshold,
        "files": results,
        "exact_duplicates": exact_duplicates,
        "similar_pairs": similarity_pairs,
    }

    out_path = _resolve_output_path(target_path, args.out)
    try:
        # Serialize before opening the file so a serialization error cannot
        # leave a truncated report on disk.
        payload = json.dumps(report, indent=2, ensure_ascii=False)
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(payload)
        print(f"Reporte guardado en: {out_path}")
    except Exception as e:
        print(f"Error guardando reporte: {e}")
        sys.exit(3)
# Run the batch evaluation only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()