df-det / app.py
nicolasrl's picture
3 labels inference
6634367 verified
import os
import torch
import onnxruntime as ort
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download, list_repo_files, HfFolder
CLASS_NAMES = ["Real", "Fake", "Semisynthetic"]
# ----------------------------------------------------------------------
# 1. Leer y guardar token de entorno
# ----------------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
raise EnvironmentError("No se encontr贸 la variable de entorno HF_TOKEN.")
HfFolder.save_token(HF_TOKEN)
# ----------------------------------------------------------------------
# 2. Localizar el modelo ONNX m谩s reciente en el repositorio
# ----------------------------------------------------------------------
repo_id = "nicolasrl/df-vs-real"
print("Buscando archivos 'model.onnx' en el repositorio...")
all_repo_files = list_repo_files(repo_id=repo_id)
onnx_repo_paths = [path for path in all_repo_files if path.endswith("model.onnx")]
if not onnx_repo_paths:
raise FileNotFoundError("No se encontr贸 ning煤n archivo model.onnx en el repositorio.")
latest_model_repo_path = max(onnx_repo_paths)
model_version = os.path.dirname(latest_model_repo_path)
print(f"Modelo m谩s reciente encontrado en el repo: {latest_model_repo_path} (Versi贸n: {model_version})")
# ----------------------------------------------------------------------
# 3. Descargar 煤nicamente el modelo ONNX seleccionado
# ----------------------------------------------------------------------
print(f"Descargando {latest_model_repo_path}...")
local_model_path = hf_hub_download(
repo_id=repo_id,
filename=latest_model_repo_path,
repo_type="model"
)
print(f"Modelo descargado en: {local_model_path}")
# ----------------------------------------------------------------------
# 4. Configurar sesi贸n ONNX Runtime
# ----------------------------------------------------------------------
providers = ['CPUExecutionProvider']
session = ort.InferenceSession(local_model_path, providers=providers)
# ----------------------------------------------------------------------
# 5. Preprocesamiento de im谩genes
# ----------------------------------------------------------------------
def preprocess(image: Image.Image):
image = image.convert("RGB").resize((224, 224))
img_array = np.array(image).astype(np.float32) / 255.0
img_array = img_array.transpose(2, 0, 1) # CHW
img_array = np.expand_dims(img_array, 0) # batch
return img_array
# ----------------------------------------------------------------------
# 6. Inferencia
# ----------------------------------------------------------------------
def predict(image: Image.Image):
input_data = preprocess(image)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# El 'output' ahora tendr谩 3 valores (logits)
output = session.run([output_name], {input_name: input_data})[0]
# 1. Aplicar Softmax para convertir los logits en probabilidades
logits = output[0]
exp_logits = np.exp(logits - np.max(logits)) # Restar el max es por estabilidad num茅rica
probabilities = exp_logits / exp_logits.sum()
# 2. Crear el diccionario de resultados para Gradio
confidences = {CLASS_NAMES[i]: float(probabilities[i]) for i in range(len(CLASS_NAMES))}
# 3. Encontrar la etiqueta con la mayor probabilidad
predicted_index = np.argmax(probabilities)
label = CLASS_NAMES[predicted_index]
return confidences, f"Predicci贸n: {label}"
# ----------------------------------------------------------------------
# 7. Interfaz Gradio
# ----------------------------------------------------------------------
demo = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil", label="Sube una imagen"),
outputs=[
gr.Label(label="Probabilidades"),
gr.Textbox(label="Resultado")
],
title="Detecci贸n de Deepfakes - ONNX",
description=f"Modelo ONNX cargado desde la versi贸n m谩s reciente del repo.\n**Modelo en uso:** {model_version}"
)
# ----------------------------------------------------------------------
# 8. Lanzar app
# ----------------------------------------------------------------------
if __name__ == "__main__":
demo.launch()