Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Batch evaluator: recorre una carpeta (o archivo) y evalúa modelos. | |
| - Por defecto evalúa clasificación | |
| - Si se pasa --sr evalúa super-resolución | |
| - Genera un archivo JSON en la misma ruta con los resultados y pares de alta similitud | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| import math | |
| from datetime import datetime | |
| import torch | |
| import torch.nn.functional as F | |
| from safetensors.torch import load_model | |
| from utils import ( | |
| cargar_etiquetas, | |
| obtener_sha256, | |
| calcular_puntaje, | |
| MODEL_TYPE_CLASIFICACION, | |
| MODEL_TYPE_SR, | |
| ) | |
| import dataset | |
| import evaluation | |
| from models import FromZero, UNetSR | |
| SIMILARITY_DEFAULT_THRESHOLD = 0.95 | |
def find_model_files(path):
    """Collect the ``.safetensor(s)`` model files designated by *path*.

    If *path* is a file, a one-element list with its absolute path is
    returned when the extension matches (case-insensitively), otherwise an
    empty list. If *path* is a directory, matching entries are returned in
    sorted name order, joined onto the directory path.
    """
    extensions = (".safetensor", ".safetensors")
    if os.path.isfile(path):
        return [os.path.abspath(path)] if path.lower().endswith(extensions) else []
    return [
        os.path.join(path, name)
        for name in sorted(os.listdir(path))
        if name.lower().endswith(extensions)
    ]
def build_dataloaders():
    """Load the label metadata and both evaluation dataloaders.

    Returns ``(num_classes, test_dl, sr_dl)`` where ``test_dl`` feeds the
    classification evaluation and ``sr_dl`` the super-resolution one.
    """
    _labels, num_classes, codigo = cargar_etiquetas()
    test_dl, sr_dl = dataset.cargar_datasets(codigo)
    return num_classes, test_dl, sr_dl
def extract_normalized_vector(model_cls, model_path, num_classes=None):
    """Load *model_path* into a fresh *model_cls* instance and flatten all
    its parameters into one L2-normalized 1-D ``torch.float32`` vector.

    Returns a tuple ``(vector | None, error_message | None, vector_length)``;
    on any failure the vector is ``None`` and the length ``0``.
    """
    try:
        model = model_cls(num_classes) if num_classes is not None else model_cls()
    except TypeError:
        # Architectures differ: UNetSR() takes no args, FromZero needs num_classes.
        model = model_cls()
    try:
        load_model(model, model_path)
        model.to("cpu")
        flat = [
            p.detach().cpu().view(-1).float()
            for p in model.parameters()
            if p.numel() > 0
        ]
        if not flat:
            return None, "Modelo sin parámetros", 0
        vec = torch.cat(flat)
        norm = float(vec.norm().item())
        if norm > 0:
            vec = vec / norm
        return vec, None, vec.numel()
    except Exception as e:
        # Report the failure through the error slot instead of raising.
        return None, str(e), 0
    finally:
        # Best-effort cleanup: drop the model reference and the CUDA cache.
        try:
            del model
        except Exception:
            pass
        if torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
def pairwise_similarities(file_vectors_by_len, threshold):
    """Report pairs of models whose parameter vectors are highly similar.

    Parameters
    ----------
    file_vectors_by_len : dict[int, list[tuple[int, str, torch.Tensor]]]
        ``(index, filepath, vector)`` triples grouped by vector length;
        only same-length vectors are comparable.
    threshold : float
        Minimum cosine similarity (inclusive) for a pair to be reported.

    Returns
    -------
    list[dict]
        One entry per qualifying pair with keys ``file_a``, ``file_b``
        and ``similarity``.
    """
    pairs = []
    for items in file_vectors_by_len.values():
        if len(items) < 2:
            continue
        paths = [it[1] for it in items]
        vecs = [it[2] for it in items]
        try:
            # Vectors are already L2-normalized, so the dot product IS the
            # cosine similarity; one matmul yields the full N x N matrix.
            stack = torch.stack(vecs)  # (N, L)
            sim_matrix = stack @ stack.t()
            n = sim_matrix.shape[0]
            for i in range(n):
                for j in range(i + 1, n):
                    sim = float(sim_matrix[i, j].item())
                    if sim >= threshold:
                        pairs.append({
                            "file_a": paths[i],
                            "file_b": paths[j],
                            "similarity": sim,
                        })
        except Exception:
            # Fallback (e.g. stacking the group is too large): compute each
            # pair individually and skip pairs that still fail.
            for i in range(len(vecs)):
                for j in range(i + 1, len(vecs)):
                    try:
                        sim = float(F.cosine_similarity(
                            vecs[i].unsqueeze(0), vecs[j].unsqueeze(0), dim=1
                        ).item())
                    except Exception:
                        continue
                    if sim >= threshold:
                        pairs.append({
                            "file_a": paths[i],
                            "file_b": paths[j],
                            "similarity": sim,
                        })
    return pairs
def main():
    """CLI entry point: evaluate every ``.safetensor`` model under a path.

    Evaluates classification by default (super-resolution with ``--sr``),
    detects exact duplicates via sha256 and near-duplicates via cosine
    similarity of the flattened parameter vectors, then writes a JSON
    report next to the evaluated models (or to ``--out``).

    Exit codes: 2 if the path does not exist, 0 if no models are found,
    3 if the report cannot be written.
    """
    from datetime import timezone  # local: file-level import only brings in `datetime`

    parser = argparse.ArgumentParser(description="Batch evaluation de modelos (.safetensor)")
    parser.add_argument("path", help="Archivo .safetensor o carpeta con modelos")
    parser.add_argument("--sr", action="store_true", help="Evaluar como super-resolución (SR)")
    parser.add_argument("--threshold", type=float, default=SIMILARITY_DEFAULT_THRESHOLD, help="Umbral de similitud (0..1) para reportar pares altamente parecidos")
    parser.add_argument("--out", default=None, help="Archivo de salida (si no se da, se usa batch_evaluation.json en la carpeta) ")
    args = parser.parse_args()

    target_path = args.path
    if not os.path.exists(target_path):
        print(f"Error: ruta no existe: {target_path}")
        sys.exit(2)
    model_files = find_model_files(target_path)
    if not model_files:
        print("No se encontraron archivos .safetensor en la ruta dada.")
        sys.exit(0)

    num_classes, test_dl, sr_dl = build_dataloaders()
    sr_mode = bool(args.sr)
    metric_label = "psnr" if sr_mode else "accuracy"

    results = []
    file_vectors_by_len = {}  # vector length -> [(idx, path, vector)]
    for idx, fpath in enumerate(model_files):
        entry = {
            "path": fpath,
            "sha256": None,
            "metric": None,
            "metric_label": metric_label,
            "score": None,
            "error": None,
            "vector_len": 0,
        }
        try:
            entry["sha256"] = obtener_sha256(fpath)
        except Exception as e:
            # Without a hash the file is unreadable; skip the rest of the work.
            entry["error"] = f"Error calculando sha256: {e}"
            results.append(entry)
            continue
        # Metric evaluation (accuracy or PSNR depending on mode).
        try:
            if sr_mode:
                metric_val = evaluation.cargar_evaluar_modelo_sr(fpath, sr_dl)
            else:
                metric_val = evaluation.cargar_evaluar_modelo_clasificacion(fpath, num_classes, test_dl)
            if isinstance(metric_val, str):
                # The evaluators report failures as strings instead of raising.
                entry["error"] = metric_val
            else:
                entry["metric"] = float(metric_val)
                entry["score"] = calcular_puntaje(
                    metric_val,
                    model_type=(MODEL_TYPE_SR if sr_mode else MODEL_TYPE_CLASIFICACION),
                )
        except Exception as e:
            entry["error"] = f"Error evaluando modelo: {e}"
        # Parameter-vector extraction for similarity detection.
        model_cls = UNetSR if sr_mode else FromZero
        try:
            vec, vec_err, vec_len = extract_normalized_vector(
                model_cls, fpath, num_classes=(None if sr_mode else num_classes)
            )
            if vec_err:
                entry["vector_error"] = vec_err
            else:
                entry["vector_len"] = int(vec_len)
                # Group by length: only same-size vectors are comparable.
                file_vectors_by_len.setdefault(int(vec_len), []).append((idx, fpath, vec))
        except Exception as e:
            entry["vector_error"] = f"Error extrayendo vector: {e}"
        results.append(entry)

    # Exact duplicates: group paths by sha256, keep groups of 2+.
    sha_groups = {}
    for r in results:
        if r.get("sha256"):
            sha_groups.setdefault(r["sha256"], []).append(r["path"])
    exact_duplicates = [
        {"sha256": sha, "files": paths}
        for sha, paths in sha_groups.items()
        if len(paths) > 1
    ]
    # Mark duplicated entries with a single set lookup instead of rescanning
    # the whole result list once per duplicated hash.
    duplicated_shas = {d["sha256"] for d in exact_duplicates}
    for r in results:
        if r.get("sha256") in duplicated_shas:
            r["duplicado"] = True

    # Near-duplicates by parameter similarity (per same-length group).
    similarity_pairs = pairwise_similarities(file_vectors_by_len, args.threshold)
    sha_by_path = {r["path"]: r.get("sha256") for r in results}
    for p in similarity_pairs:
        sha_a = sha_by_path.get(p["file_a"])
        sha_b = sha_by_path.get(p["file_b"])
        if sha_a and sha_b and sha_a == sha_b:
            p["note"] = "Exact duplicate (same sha256)"
        else:
            p["note"] = "High parameter similarity"

    report = {
        # datetime.utcnow() is deprecated; this produces the same
        # "YYYY-MM-DDTHH:MM:SS.ffffffZ" string from an aware UTC timestamp.
        "evaluated_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
        "mode": MODEL_TYPE_SR if sr_mode else MODEL_TYPE_CLASIFICACION,
        "similarity_threshold": args.threshold,
        "files": results,
        "exact_duplicates": exact_duplicates,
        "similar_pairs": similarity_pairs,
    }

    # Resolve the output path: explicit --out wins; bare filenames land next
    # to the evaluated models, as does the default report name.
    base_dir = target_path if os.path.isdir(target_path) else os.path.dirname(target_path)
    if args.out:
        if os.path.isabs(args.out) or os.path.dirname(args.out):
            out_path = args.out
        else:
            out_path = os.path.join(base_dir, args.out)
    else:
        out_path = os.path.join(base_dir, "batch_evaluation.json")

    try:
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2, ensure_ascii=False)
        print(f"Reporte guardado en: {out_path}")
    except Exception as e:
        print(f"Error guardando reporte: {e}")
        sys.exit(3)
| if __name__ == "__main__": | |
| main() | |