bithal26 commited on
Commit
bff8c75
·
verified ·
1 Parent(s): aaf81e3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn.modules.dropout import Dropout
4
+ from torch.nn.modules.linear import Linear
5
+ from torch.nn.modules.pooling import AdaptiveAvgPool2d
6
+ from timm.models.efficientnet import tf_efficientnet_b7_ns
7
+ from functools import partial
8
+ import re
9
+ import gradio as gr
10
+ import os
11
+
12
+ # --- 1. MODEL ARCHITECTURE ---
13
+ encoder_params = {
14
+ "tf_efficientnet_b7_ns": {
15
+ "features": 2560,
16
+ "init_op": partial(tf_efficientnet_b7_ns, pretrained=False, drop_path_rate=0.2)
17
+ }
18
+ }
19
+
20
+ class DeepFakeClassifier(nn.Module):
21
+ def __init__(self, encoder="tf_efficientnet_b7_ns", dropout_rate=0.0) -> None:
22
+ super().__init__()
23
+ self.encoder = encoder_params[encoder]["init_op"]()
24
+ self.avg_pool = AdaptiveAvgPool2d((1, 1))
25
+ self.dropout = Dropout(dropout_rate)
26
+ self.fc = Linear(encoder_params[encoder]["features"], 1)
27
+
28
+ def forward(self, x):
29
+ x = self.encoder.forward_features(x)
30
+ x = self.avg_pool(x).flatten(1)
31
+ x = self.dropout(x)
32
+ x = self.fc(x)
33
+ return x
34
+
35
+ # --- 2. LOAD THE SPECIFIC WEIGHT FILE ---
36
+ # >>> UPDATE THIS STRING IN EVERY SPACE TO MATCH THE UPLOADED FILE EXACTLY <<<
37
+ WEIGHT_FILE = "final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_29"
38
+
39
+ print(f"Booting API Worker: Loading {WEIGHT_FILE}...")
40
+ device = torch.device('cpu')
41
+ model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to(device)
42
+
43
+ # Using weights_only=False to bypass PyTorch 2.6 security restriction
44
+ checkpoint = torch.load(WEIGHT_FILE, map_location="cpu", weights_only=False)
45
+ state_dict = checkpoint.get("state_dict", checkpoint)
46
+ model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
47
+ model.eval()
48
+ print("Model loaded successfully. API Ready.")
49
+
50
+ # --- 3. THE API INFERENCE FUNCTION ---
51
+ def predict_tensor(tensor_file):
52
+ """
53
+ Receives a preprocessed .pt tensor file from the Master UI,
54
+ runs inference, and returns the confidence score array.
55
+ """
56
+ if tensor_file is None:
57
+ return {"error": "No tensor file received"}
58
+
59
+ try:
60
+ # Load the pre-processed tensor sent by the Master UI
61
+ x = torch.load(tensor_file.name, map_location=device, weights_only=True)
62
+
63
+ with torch.no_grad():
64
+ y_pred = model(x)
65
+ y_pred = torch.sigmoid(y_pred.squeeze())
66
+
67
+ # Format output so it can be sent back via JSON
68
+ if y_pred.dim() == 0:
69
+ bpred = [float(y_pred.cpu().numpy())]
70
+ else:
71
+ bpred = y_pred.cpu().numpy().tolist()
72
+
73
+ return {"predictions": bpred}
74
+ except Exception as e:
75
+ return {"error": str(e)}
76
+
77
+ # --- 4. GRADIO API INTERFACE ---
78
+ interface = gr.Interface(
79
+ fn=predict_tensor,
80
+ inputs=gr.File(label="Input Tensor (.pt)"),
81
+ outputs=gr.JSON(label="Prediction Array"),
82
+ title="DeepGuard Worker API",
83
+ description="Microservice endpoint for EfficientNet-B7 Deepfake Inference."
84
+ )
85
+
86
+ if __name__ == "__main__":
87
+ interface.launch()