nicolasrl committed on
Commit
8b9a8b2
verified
1 Parent(s): e6030b4

takes last model from repo

Browse files
Files changed (1) hide show
  1. app.py +85 -32
app.py CHANGED
@@ -1,32 +1,85 @@
# --- Previous version of app.py (transformers-based pipeline) ---
import gradio as gr
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import torch
import os

MODEL_ID = "nicolasrl/deepfake_vs_real_ViTlarge"
HF_TOKEN = os.getenv("HF_TOKEN")

# Fix: `use_auth_token` is deprecated in recent transformers releases;
# `token` is the supported keyword for authenticated downloads.
model = AutoModelForImageClassification.from_pretrained(MODEL_ID, token=HF_TOKEN)
processor = AutoImageProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)

def predict(image: Image.Image):
    """Classify *image* and return {label: probability}, sorted descending.

    Runs the HF image classifier without gradients and maps class indices
    to human-readable labels via the model config.
    """
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    results = {model.config.id2label[i]: float(probs[i]) for i in range(len(probs))}
    return dict(sorted(results.items(), key=lambda x: x[1], reverse=True))

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=2),
    title="Deepfake o Real",
    description="Sube una imagen y el modelo hace magia para predecir si es Deepfake o Real",
    api_name="predict"
)

if __name__ == "__main__":
    iface.launch(show_error=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import glob
import torch  # NOTE(review): appears unused by the ONNX pipeline below — confirm before removing
import onnxruntime as ort
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download

# ----------------------------------------------------------------------
# 1. Read the access token from the environment
# ----------------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # Fail fast with an actionable message. Fix: the original string was
    # mojibake-damaged ("encontr贸", "A帽谩dela"); restored to proper UTF-8.
    raise EnvironmentError(
        "No se encontró la variable de entorno HF_TOKEN. "
        "Añádela en los Secrets del Space."
    )
# ----------------------------------------------------------------------
# 2. Download the model repository (authenticated)
# ----------------------------------------------------------------------
repo_id = "nicolasrl/df-vs-real"
# Fix: `use_auth_token` is deprecated in huggingface_hub; `token` is the
# supported keyword for authenticated downloads.
local_dir = snapshot_download(
    repo_id,
    local_dir="repo",
    repo_type="model",
    token=HF_TOKEN,
)

# ----------------------------------------------------------------------
# 3. Locate the most recent ONNX model file
# ----------------------------------------------------------------------
onnx_paths = glob.glob(os.path.join(local_dir, "**/model.onnx"), recursive=True)
if not onnx_paths:
    # Mojibake in the message restored to proper UTF-8 Spanish.
    raise FileNotFoundError("No se encontró ningún archivo model.onnx en el repositorio.")

# NOTE(review): mtime of a fresh snapshot reflects download order, not
# upstream commit time — confirm this reliably picks the intended
# "latest" model when several model.onnx files exist.
latest_model_path = max(onnx_paths, key=os.path.getmtime)
print(f"Usando modelo: {latest_model_path}")
# ----------------------------------------------------------------------
# 4. Configure the ONNX Runtime session (CPU-only Space)
# ----------------------------------------------------------------------
providers = ["CPUExecutionProvider"]  # no GPU available on this Space
session = ort.InferenceSession(latest_model_path, providers=providers)
# ----------------------------------------------------------------------
# 5. Image preprocessing
# ----------------------------------------------------------------------
def preprocess(image: Image.Image):
    """Turn a PIL image into a (1, 3, 224, 224) float32 NCHW array in [0, 1].

    NOTE(review): no mean/std normalization is applied here — confirm the
    exported ONNX graph embeds its own normalization.
    """
    rgb = image.convert("RGB").resize((224, 224))
    chw = np.asarray(rgb, dtype=np.float32).transpose(2, 0, 1) / 255.0
    return chw[np.newaxis, ...]  # add batch dimension
# ----------------------------------------------------------------------
# 6. Inference
# ----------------------------------------------------------------------
def predict(image: Image.Image):
    """Run the ONNX model on *image*.

    Returns:
        tuple: ({"Deepfake": p, "Real": 1 - p}, "Predicción: <label>").

    NOTE(review): the first output element is treated as a ready-made
    probability in [0, 1]; if the exported graph emits raw logits instead,
    a sigmoid is missing here — confirm against the model.
    """
    input_data = preprocess(image)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    output = session.run([output_name], {input_name: input_data})[0]

    prob_fake = float(output[0][0])
    label = "Deepfake" if prob_fake > 0.5 else "Real"
    # Fix: the original f-string was mojibake-damaged ("Predicci贸n");
    # restored so the UI shows correct Spanish.
    return {"Deepfake": prob_fake, "Real": 1 - prob_fake}, f"Predicción: {label}"
# ----------------------------------------------------------------------
# 7. Gradio interface
# ----------------------------------------------------------------------
# Fix: title and description strings were mojibake-damaged
# ("Detecci贸n", "autom谩ticamente", "versi贸n m谩s"); restored to UTF-8.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Sube una imagen"),
    outputs=[
        gr.Label(label="Probabilidades"),
        gr.Textbox(label="Resultado"),
    ],
    title="Detección de Deepfakes - ONNX",
    description="Modelo ONNX cargado automáticamente desde la versión más reciente del repo.",
)

# ----------------------------------------------------------------------
# 8. Launch the app
# ----------------------------------------------------------------------
if __name__ == "__main__":
    demo.launch()