import os
import re

import gradio as gr
import numpy as np
import onnxruntime as ort
import torch
from huggingface_hub import hf_hub_download, list_repo_files, HfFolder
from PIL import Image
# Class labels in the exact order of the model's output logits.
CLASS_NAMES = ["Real", "Fake", "Semisynthetic"]

# ----------------------------------------------------------------------
# 1. Read and persist the Hugging Face token from the environment
# ----------------------------------------------------------------------
# The repo may be private; persisting the token lets hf_hub_download /
# list_repo_files authenticate implicitly without passing token= each call.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # Fixed mojibake ("encontr贸" -> "encontró") in the user-facing message.
    raise EnvironmentError("No se encontró la variable de entorno HF_TOKEN.")
HfFolder.save_token(HF_TOKEN)
# ----------------------------------------------------------------------
# 2. Locate the most recent ONNX model in the repository
# ----------------------------------------------------------------------
repo_id = "nicolasrl/df-vs-real"


def _version_key(path):
    """Natural-sort key for repo paths: digit runs compare numerically.

    Plain ``max(paths)`` is lexicographic, so 'v9/model.onnx' would beat
    'v10/model.onnx'. Splitting with a capturing group keeps the digit
    runs in the result and converts them to ints for numeric ordering.
    """
    return tuple(int(tok) if tok.isdigit() else tok
                 for tok in re.split(r"(\d+)", path))


print("Buscando archivos 'model.onnx' en el repositorio...")
all_repo_files = list_repo_files(repo_id=repo_id)
onnx_repo_paths = [path for path in all_repo_files if path.endswith("model.onnx")]
if not onnx_repo_paths:
    # Fixed mojibake in the error message ("encontr贸 ning煤n" -> "encontró ningún").
    raise FileNotFoundError("No se encontró ningún archivo model.onnx en el repositorio.")
# Newest version directory via numeric-aware ordering (bug fix: was plain max()).
latest_model_repo_path = max(onnx_repo_paths, key=_version_key)
model_version = os.path.dirname(latest_model_repo_path)
print(f"Modelo más reciente encontrado en el repo: {latest_model_repo_path} (Versión: {model_version})")
# ----------------------------------------------------------------------
# 3. Download just the selected ONNX model file (no full repo snapshot)
# ----------------------------------------------------------------------
print(f"Descargando {latest_model_repo_path}...")
local_model_path = hf_hub_download(
    repo_id=repo_id, filename=latest_model_repo_path, repo_type="model",
)
print(f"Modelo descargado en: {local_model_path}")
# ----------------------------------------------------------------------
# 4. Configure the ONNX Runtime session
# ----------------------------------------------------------------------
# CPU-only: no GPU provider is requested, so inference always runs on the
# CPUExecutionProvider regardless of host hardware.
providers = ['CPUExecutionProvider']
session = ort.InferenceSession(local_model_path, providers=providers)
# ----------------------------------------------------------------------
# 5. Image preprocessing
# ----------------------------------------------------------------------
def preprocess(image: Image.Image, size: tuple = (224, 224)) -> np.ndarray:
    """Convert a PIL image into a float32 NCHW tensor scaled to [0, 1].

    Args:
        image: Input image in any PIL mode; converted to RGB first.
        size: Target (width, height); defaults to the 224x224 input the
            deployed classifier expects (now parameterized for reuse).

    Returns:
        np.ndarray of shape (1, 3, size[1], size[0]), dtype float32.
    """
    image = image.convert("RGB").resize(size)
    img_array = np.asarray(image, dtype=np.float32) / 255.0
    img_array = img_array.transpose(2, 0, 1)  # HWC -> CHW
    return np.expand_dims(img_array, 0)       # add batch dimension
# ----------------------------------------------------------------------
# 6. Inference
# ----------------------------------------------------------------------
def predict(image: Image.Image):
    """Classify *image* with the ONNX model and build Gradio-ready outputs.

    Args:
        image: PIL image supplied by the Gradio Image component.

    Returns:
        tuple: (dict mapping each class name to its softmax probability,
        display string with the top-1 label).
    """
    input_data = preprocess(image)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    # The model emits one logit per class; take the single batch row.
    logits = session.run([output_name], {input_name: input_data})[0][0]
    # Numerically stable softmax: subtracting the max avoids exp overflow.
    exp_logits = np.exp(logits - np.max(logits))
    probabilities = exp_logits / exp_logits.sum()
    confidences = {name: float(p) for name, p in zip(CLASS_NAMES, probabilities)}
    label = CLASS_NAMES[int(np.argmax(probabilities))]
    # Fixed mojibake in the user-facing string ("Predicci贸n" -> "Predicción").
    return confidences, f"Predicción: {label}"
# ----------------------------------------------------------------------
# 7. Gradio interface
# ----------------------------------------------------------------------
# Fixed mojibake in the user-facing title/description strings
# ("Detecci贸n" -> "Detección", "versi贸n m谩s" -> "versión más").
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Sube una imagen"),
    outputs=[
        gr.Label(label="Probabilidades"),
        gr.Textbox(label="Resultado"),
    ],
    title="Detección de Deepfakes - ONNX",
    description=f"Modelo ONNX cargado desde la versión más reciente del repo.\n**Modelo en uso:** {model_version}",
)
# ----------------------------------------------------------------------
# 8. Launch the app (only when run as a script, not when imported)
# ----------------------------------------------------------------------
if __name__ == "__main__":
    demo.launch()