File size: 4,283 Bytes
8b9a8b2
 
 
 
 
 
2b19dce
8b9a8b2
6634367
 
 
 
8b9a8b2
b110258
8b9a8b2
 
 
6634367
8b9a8b2
b110258
 
8b9a8b2
b110258
8b9a8b2
 
 
b110258
2b19dce
 
 
6634367
8b9a8b2
 
b110258
 
 
 
 
 
 
 
 
 
 
 
 
 
8b9a8b2
 
 
 
 
b110258
8b9a8b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6634367
 
8b9a8b2
6634367
 
 
 
 
 
 
 
 
 
 
 
8b9a8b2
6634367
8b9a8b2
 
 
 
 
 
 
 
 
 
 
 
b110258
8b9a8b2
 
 
 
 
 
b110258
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import torch
import onnxruntime as ort
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download, list_repo_files, HfFolder


CLASS_NAMES = ["Real", "Fake", "Semisynthetic"]


# ----------------------------------------------------------------------
# 1. Leer y guardar token de entorno
# ----------------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise EnvironmentError("No se encontr贸 la variable de entorno HF_TOKEN.")

HfFolder.save_token(HF_TOKEN)

# ----------------------------------------------------------------------
# 2. Localizar el modelo ONNX m谩s reciente en el repositorio
# ----------------------------------------------------------------------
repo_id = "nicolasrl/df-vs-real"

print("Buscando archivos 'model.onnx' en el repositorio...")
all_repo_files = list_repo_files(repo_id=repo_id)
onnx_repo_paths = [path for path in all_repo_files if path.endswith("model.onnx")]

if not onnx_repo_paths:
    raise FileNotFoundError("No se encontr贸 ning煤n archivo model.onnx en el repositorio.")

latest_model_repo_path = max(onnx_repo_paths)
model_version = os.path.dirname(latest_model_repo_path)
print(f"Modelo m谩s reciente encontrado en el repo: {latest_model_repo_path} (Versi贸n: {model_version})")

# ----------------------------------------------------------------------
# 3. Descargar 煤nicamente el modelo ONNX seleccionado
# ----------------------------------------------------------------------
print(f"Descargando {latest_model_repo_path}...")
local_model_path = hf_hub_download(
    repo_id=repo_id,
    filename=latest_model_repo_path,
    repo_type="model"
)
print(f"Modelo descargado en: {local_model_path}")

# ----------------------------------------------------------------------
# 4. Configurar sesi贸n ONNX Runtime
# ----------------------------------------------------------------------
providers = ['CPUExecutionProvider']
session = ort.InferenceSession(local_model_path, providers=providers)

# ----------------------------------------------------------------------
# 5. Preprocesamiento de im谩genes
# ----------------------------------------------------------------------
def preprocess(image: Image.Image):
    image = image.convert("RGB").resize((224, 224))
    img_array = np.array(image).astype(np.float32) / 255.0
    img_array = img_array.transpose(2, 0, 1)  # CHW
    img_array = np.expand_dims(img_array, 0)  # batch
    return img_array

# ----------------------------------------------------------------------
# 6. Inferencia
# ----------------------------------------------------------------------
def predict(image: Image.Image):
    input_data = preprocess(image)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    
    # El 'output' ahora tendr谩 3 valores (logits)
    output = session.run([output_name], {input_name: input_data})[0]
    
    # 1. Aplicar Softmax para convertir los logits en probabilidades
    logits = output[0]
    exp_logits = np.exp(logits - np.max(logits)) # Restar el max es por estabilidad num茅rica
    probabilities = exp_logits / exp_logits.sum()

    # 2. Crear el diccionario de resultados para Gradio
    confidences = {CLASS_NAMES[i]: float(probabilities[i]) for i in range(len(CLASS_NAMES))}
    
    # 3. Encontrar la etiqueta con la mayor probabilidad
    predicted_index = np.argmax(probabilities)
    label = CLASS_NAMES[predicted_index]

    return confidences, f"Predicci贸n: {label}"

# ----------------------------------------------------------------------
# 7. Interfaz Gradio
# ----------------------------------------------------------------------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Sube una imagen"),
    outputs=[
        gr.Label(label="Probabilidades"),
        gr.Textbox(label="Resultado")
    ],
    title="Detecci贸n de Deepfakes - ONNX",
    description=f"Modelo ONNX cargado desde la versi贸n m谩s reciente del repo.\n**Modelo en uso:** {model_version}"
)

# ----------------------------------------------------------------------
# 8. Lanzar app
# ----------------------------------------------------------------------
if __name__ == "__main__":
    demo.launch()