#!/usr/bin/env python3
"""Batch evaluator: walk a folder (or a single file) and evaluate models.

- Evaluates classification by default.
- Evaluates super-resolution when ``--sr`` is passed.
- Writes a JSON report next to the models with the per-file results, the
  exact duplicates (same sha256) and the pairs of highly similar models.
"""
import argparse
import itertools
import json
import os
import sys
import math
from datetime import datetime, timezone

import torch
import torch.nn.functional as F
from safetensors.torch import load_model

from utils import (
    cargar_etiquetas,
    obtener_sha256,
    calcular_puntaje,
    MODEL_TYPE_CLASIFICACION,
    MODEL_TYPE_SR,
)
import dataset
import evaluation
from models import FromZero, UNetSR

SIMILARITY_DEFAULT_THRESHOLD = 0.95

# File suffixes (compared case-insensitively) recognised as model files.
_MODEL_SUFFIXES = (".safetensor", ".safetensors")


def find_model_files(path):
    """Return the model files designated by *path*.

    If *path* is a file, the result is a one-element list with its absolute
    path when it has a model suffix, otherwise an empty list.  If *path* is
    a directory, the result lists its direct entries (sorted by name, not
    recursive) that are regular files with a model suffix, each joined to
    *path*.
    """
    if os.path.isfile(path):
        if path.lower().endswith(_MODEL_SUFFIXES):
            return [os.path.abspath(path)]
        return []

    candidates = (
        os.path.join(path, name)
        for name in sorted(os.listdir(path))
        if name.lower().endswith(_MODEL_SUFFIXES)
    )
    # Skip sub-directories that merely happen to carry a model suffix.
    return [candidate for candidate in candidates if os.path.isfile(candidate)]


def build_dataloaders():
    """Load the labels and datasets.

    Returns:
        ``(num_classes, test_dl, sr_dl)`` where the two loaders are whatever
        ``dataset.cargar_datasets`` returns for the label code.
    """
    _, num_classes, codigo = cargar_etiquetas()
    test_dl, sr_dl = dataset.cargar_datasets(codigo)
    return num_classes, test_dl, sr_dl


def extract_normalized_vector(model_cls, model_path, num_classes=None):
    """Load a model file into *model_cls* and flatten its parameters.

    Args:
        model_cls: Model class to instantiate.  It is called with
            ``num_classes`` when that is not ``None``; if that call raises
            ``TypeError`` the class is called without arguments instead.
        model_path: Path of the safetensors file to load.
        num_classes: Optional constructor argument for *model_cls*.

    Returns:
        ``(vector, error, length)``.  On success *vector* is a 1-D
        ``torch.float32`` tensor with every parameter concatenated and
        divided by its L2 norm (left unscaled when the norm is 0), *error*
        is ``None`` and *length* is the number of elements.  On failure the
        tuple is ``(None, <non-empty message>, 0)``.
    """
    model = None
    try:
        if num_classes is None:
            model = model_cls()
        else:
            try:
                model = model_cls(num_classes)
            except TypeError:
                # The class takes no arguments (e.g. UNetSR), unlike
                # FromZero which needs num_classes.
                model = model_cls()

        load_model(model, model_path)
        model.to("cpu")

        # reshape (not view) so non-contiguous parameters are accepted too.
        parts = [
            flat
            for flat in (p.detach().cpu().reshape(-1).float() for p in model.parameters())
            if flat.numel() > 0
        ]
        if not parts:
            return None, "Modelo sin parámetros", 0

        vec = torch.cat(parts)
        norm = float(vec.norm().item())
        if norm > 0:
            vec = vec / norm
        return vec, None, vec.numel()
    except Exception as exc:
        # Callers test the error string for truthiness, so never return an
        # empty message.
        return None, str(exc) or type(exc).__name__, 0
    finally:
        # Drop the model reference and release cached GPU memory; both are
        # best-effort clean-up and must never mask the real result.
        del model
        if torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass


def _safe_cosine(vec_a, vec_b):
    """Return the cosine similarity of two 1-D tensors, or None on failure."""
    try:
        return float(
            F.cosine_similarity(vec_a.unsqueeze(0), vec_b.unsqueeze(0), dim=1).item()
        )
    except Exception:
        return None


def pairwise_similarities(file_vectors_by_len, threshold):
    """Report every pair of same-length vectors at least *threshold* similar.

    Args:
        file_vectors_by_len: Mapping ``length -> [(index, filepath, vector)]``.
            Only vectors within the same group are compared.
        threshold: Minimum similarity for a pair to be reported.

    Returns:
        A list of ``{"file_a", "file_b", "similarity"}`` dicts, ordered by
        group and then by pair position ``(i, j)`` with ``i < j``.
    """
    pairs = []
    for items in file_vectors_by_len.values():
        if len(items) < 2:
            continue
        paths = [item[1] for item in items]
        vecs = [item[2] for item in items]
        index_pairs = list(itertools.combinations(range(len(vecs)), 2))

        # Compute all similarities before emitting anything, so a failure of
        # the fast path cannot leave half of the pairs reported twice.
        try:
            stack = torch.stack(vecs)  # (N, L)
            # The vectors are treated as already normalized, so the dot
            # product is used as the cosine similarity.
            sim_rows = (stack @ stack.t()).tolist()
            sims = [sim_rows[i][j] for i, j in index_pairs]
        except Exception:
            # Fallback: one pair at a time; pairs that fail are skipped.
            sims = [_safe_cosine(vecs[i], vecs[j]) for i, j in index_pairs]

        for (i, j), sim in zip(index_pairs, sims):
            if sim is not None and sim >= threshold:
                pairs.append({
                    "file_a": paths[i],
                    "file_b": paths[j],
                    "similarity": sim,
                })
    return pairs


def _build_parser():
    """Create the command-line parser."""
    parser = argparse.ArgumentParser(description="Batch evaluation de modelos (.safetensor)")
    parser.add_argument("path", help="Archivo .safetensor o carpeta con modelos")
    parser.add_argument("--sr", action="store_true", help="Evaluar como super-resolución (SR)")
    parser.add_argument(
        "--threshold",
        type=float,
        default=SIMILARITY_DEFAULT_THRESHOLD,
        help="Umbral de similitud (0..1) para reportar pares altamente parecidos",
    )
    parser.add_argument(
        "--out",
        default=None,
        help="Archivo de salida (si no se da, se usa batch_evaluation.json en la carpeta) ",
    )
    return parser


def _evaluate_model_file(fpath, sr_mode, num_classes, test_dl, sr_dl):
    """Hash, evaluate and vectorize one model file.

    Returns:
        ``(entry, vector)`` where *entry* is the report dict for the file and
        *vector* is its normalized parameter vector, or ``None`` when it
        could not be extracted.  If hashing fails, evaluation and vector
        extraction are skipped.
    """
    entry = {
        "path": fpath,
        "sha256": None,
        "metric": None,
        "metric_label": "psnr" if sr_mode else "accuracy",
        "score": None,
        "error": None,
        "vector_len": 0,
    }

    try:
        entry["sha256"] = obtener_sha256(fpath)
    except Exception as exc:
        entry["error"] = f"Error calculando sha256: {exc}"
        return entry, None

    model_type = MODEL_TYPE_SR if sr_mode else MODEL_TYPE_CLASIFICACION
    try:
        if sr_mode:
            metric_val = evaluation.cargar_evaluar_modelo_sr(fpath, sr_dl)
        else:
            metric_val = evaluation.cargar_evaluar_modelo_clasificacion(
                fpath, num_classes, test_dl
            )
        if isinstance(metric_val, str):
            # A string result is treated as an error message, not a metric.
            entry["error"] = metric_val
        else:
            # NOTE(review): a non-finite metric would be written by json as
            # Infinity/NaN, which strict JSON parsers reject - confirm that
            # report consumers tolerate it.
            entry["metric"] = float(metric_val)
            entry["score"] = calcular_puntaje(metric_val, model_type=model_type)
    except Exception as exc:
        entry["error"] = f"Error evaluando modelo: {exc}"

    vec = None
    try:
        vec, vec_err, vec_len = extract_normalized_vector(
            UNetSR if sr_mode else FromZero,
            fpath,
            num_classes=(None if sr_mode else num_classes),
        )
        if vec is None:
            entry["vector_error"] = vec_err
        else:
            entry["vector_len"] = int(vec_len)
    except Exception as exc:
        vec = None
        entry["vector_error"] = f"Error extrayendo vector: {exc}"
    return entry, vec


def _find_exact_duplicates(results):
    """Group results by sha256 and flag those that share one.

    Every entry whose hash occurs more than once gets ``"duplicado": True``.
    Returns a list of ``{"sha256", "files"}`` dicts, one per repeated hash.
    """
    sha_groups = {}
    for entry in results:
        sha = entry.get("sha256")
        if sha:
            sha_groups.setdefault(sha, []).append(entry)

    exact_duplicates = []
    for sha, group in sha_groups.items():
        if len(group) > 1:
            exact_duplicates.append({"sha256": sha, "files": [e["path"] for e in group]})
            for entry in group:
                entry["duplicado"] = True
    return exact_duplicates


def _annotate_pairs(similarity_pairs, results):
    """Add a ``note`` to each pair telling exact duplicates from look-alikes."""
    sha_by_path = {}
    for entry in results:
        # setdefault keeps the first entry for a path.
        sha_by_path.setdefault(entry["path"], entry["sha256"])

    for pair in similarity_pairs:
        sha_a = sha_by_path.get(pair["file_a"])
        sha_b = sha_by_path.get(pair["file_b"])
        if sha_a and sha_b and sha_a == sha_b:
            pair["note"] = "Exact duplicate (same sha256)"
        else:
            pair["note"] = "High parameter similarity"


def _resolve_output_path(target_path, out_arg):
    """Return where the report is written.

    An *out_arg* that is absolute or contains a directory component is used
    as given.  Otherwise the report goes in the target directory (or the
    directory of the target file), named *out_arg* or, when that is empty,
    ``batch_evaluation.json``.
    """
    if out_arg and (os.path.isabs(out_arg) or os.path.dirname(out_arg)):
        return out_arg
    base = target_path if os.path.isdir(target_path) else os.path.dirname(target_path)
    return os.path.join(base, out_arg or "batch_evaluation.json")


def main():
    """Command-line entry point.

    Exit codes: 2 when the path does not exist, 0 when no model files are
    found, 3 when the report cannot be written.
    """
    args = _build_parser().parse_args()

    target_path = args.path
    if not os.path.exists(target_path):
        print(f"Error: ruta no existe: {target_path}", file=sys.stderr)
        sys.exit(2)

    model_files = find_model_files(target_path)
    if not model_files:
        print("No se encontraron archivos .safetensor en la ruta dada.")
        sys.exit(0)

    num_classes, test_dl, sr_dl = build_dataloaders()
    sr_mode = bool(args.sr)

    # Collect the per-file results and the vectors grouped by length.
    results = []
    file_vectors_by_len = {}  # length -> [(idx, path, vector)]
    for idx, fpath in enumerate(model_files):
        entry, vec = _evaluate_model_file(fpath, sr_mode, num_classes, test_dl, sr_dl)
        results.append(entry)
        if vec is not None:
            file_vectors_by_len.setdefault(entry["vector_len"], []).append((idx, fpath, vec))

    exact_duplicates = _find_exact_duplicates(results)

    # Similarities are only computed inside groups of equal vector length.
    similarity_pairs = pairwise_similarities(file_vectors_by_len, args.threshold)
    _annotate_pairs(similarity_pairs, results)

    # Aware UTC clock (utcnow() is deprecated); tzinfo is dropped so the
    # timestamp keeps its "...Z" form.
    evaluated_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    report = {
        "evaluated_at": evaluated_at,
        "mode": MODEL_TYPE_SR if sr_mode else MODEL_TYPE_CLASIFICACION,
        "similarity_threshold": args.threshold,
        "files": results,
        "exact_duplicates": exact_duplicates,
        "similar_pairs": similarity_pairs,
    }

    out_path = _resolve_output_path(target_path, args.out)
    try:
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2, ensure_ascii=False)
        print(f"Reporte guardado en: {out_path}")
    except Exception as exc:
        print(f"Error guardando reporte: {exc}", file=sys.stderr)
        sys.exit(3)


if __name__ == "__main__":
    main()