import os import torch import onnxruntime as ort import numpy as np from PIL import Image import gradio as gr from huggingface_hub import hf_hub_download, list_repo_files, HfFolder CLASS_NAMES = ["Real", "Fake", "Semisynthetic"] # ---------------------------------------------------------------------- # 1. Leer y guardar token de entorno # ---------------------------------------------------------------------- HF_TOKEN = os.getenv("HF_TOKEN") if not HF_TOKEN: raise EnvironmentError("No se encontró la variable de entorno HF_TOKEN.") HfFolder.save_token(HF_TOKEN) # ---------------------------------------------------------------------- # 2. Localizar el modelo ONNX más reciente en el repositorio # ---------------------------------------------------------------------- repo_id = "nicolasrl/df-vs-real" print("Buscando archivos 'model.onnx' en el repositorio...") all_repo_files = list_repo_files(repo_id=repo_id) onnx_repo_paths = [path for path in all_repo_files if path.endswith("model.onnx")] if not onnx_repo_paths: raise FileNotFoundError("No se encontró ningún archivo model.onnx en el repositorio.") latest_model_repo_path = max(onnx_repo_paths) model_version = os.path.dirname(latest_model_repo_path) print(f"Modelo más reciente encontrado en el repo: {latest_model_repo_path} (Versión: {model_version})") # ---------------------------------------------------------------------- # 3. Descargar únicamente el modelo ONNX seleccionado # ---------------------------------------------------------------------- print(f"Descargando {latest_model_repo_path}...") local_model_path = hf_hub_download( repo_id=repo_id, filename=latest_model_repo_path, repo_type="model" ) print(f"Modelo descargado en: {local_model_path}") # ---------------------------------------------------------------------- # 4. Configurar sesión ONNX Runtime # ---------------------------------------------------------------------- providers = ['CPUExecutionProvider'] session = ort.InferenceSession(local_model_path, providers=providers) # ---------------------------------------------------------------------- # 5. Preprocesamiento de imágenes # ---------------------------------------------------------------------- def preprocess(image: Image.Image): image = image.convert("RGB").resize((224, 224)) img_array = np.array(image).astype(np.float32) / 255.0 img_array = img_array.transpose(2, 0, 1) # CHW img_array = np.expand_dims(img_array, 0) # batch return img_array # ---------------------------------------------------------------------- # 6. Inferencia # ---------------------------------------------------------------------- def predict(image: Image.Image): input_data = preprocess(image) input_name = session.get_inputs()[0].name output_name = session.get_outputs()[0].name # El 'output' ahora tendrá 3 valores (logits) output = session.run([output_name], {input_name: input_data})[0] # 1. Aplicar Softmax para convertir los logits en probabilidades logits = output[0] exp_logits = np.exp(logits - np.max(logits)) # Restar el max es por estabilidad numérica probabilities = exp_logits / exp_logits.sum() # 2. Crear el diccionario de resultados para Gradio confidences = {CLASS_NAMES[i]: float(probabilities[i]) for i in range(len(CLASS_NAMES))} # 3. Encontrar la etiqueta con la mayor probabilidad predicted_index = np.argmax(probabilities) label = CLASS_NAMES[predicted_index] return confidences, f"Predicción: {label}" # ---------------------------------------------------------------------- # 7. Interfaz Gradio # ---------------------------------------------------------------------- demo = gr.Interface( fn=predict, inputs=gr.Image(type="pil", label="Sube una imagen"), outputs=[ gr.Label(label="Probabilidades"), gr.Textbox(label="Resultado") ], title="Detección de Deepfakes - ONNX", description=f"Modelo ONNX cargado desde la versión más reciente del repo.\n**Modelo en uso:** {model_version}" ) # ---------------------------------------------------------------------- # 8. Lanzar app # ---------------------------------------------------------------------- if __name__ == "__main__": demo.launch()