bithal26 commited on
Commit
0bd2258
·
verified ·
1 Parent(s): d1e7739

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -24
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  from torch import nn
3
  from torch.nn.modules.dropout import Dropout
@@ -9,7 +11,7 @@ import re
9
  import gradio as gr
10
  import os
11
 
12
- # --- 1. MODEL ARCHITECTURE ---
13
  encoder_params = {
14
  "tf_efficientnet_b7_ns": {
15
  "features": 2560,
@@ -32,39 +34,27 @@ class DeepFakeClassifier(nn.Module):
32
  x = self.fc(x)
33
  return x
34
 
35
- # --- 2. LOAD THE SPECIFIC WEIGHT FILE ---
36
- # >>> UPDATE THIS STRING IN EVERY SPACE TO MATCH THE UPLOADED FILE EXACTLY <<<
37
- WEIGHT_FILE = "final_555_DeepFakeClassifier_tf_efficientnet_b7_ns_0_19"
38
-
39
  print(f"Booting API Worker: Loading {WEIGHT_FILE}...")
40
  device = torch.device('cpu')
41
  model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to(device)
42
 
43
- # Using weights_only=False to bypass PyTorch 2.6 security restriction
44
  checkpoint = torch.load(WEIGHT_FILE, map_location="cpu", weights_only=False)
45
  state_dict = checkpoint.get("state_dict", checkpoint)
46
  model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
47
  model.eval()
48
- print("Model loaded successfully. API Ready.")
49
 
50
- # --- 3. THE API INFERENCE FUNCTION ---
51
  def predict_tensor(tensor_file):
52
- """
53
- Receives a preprocessed .pt tensor file from the Master UI,
54
- runs inference, and returns the confidence score array.
55
- """
56
  if tensor_file is None:
57
  return {"error": "No tensor file received"}
58
-
59
  try:
60
- # Load the pre-processed tensor sent by the Master UI
61
  x = torch.load(tensor_file.name, map_location=device, weights_only=True)
62
-
63
  with torch.no_grad():
64
  y_pred = model(x)
65
  y_pred = torch.sigmoid(y_pred.squeeze())
66
 
67
- # Format output so it can be sent back via JSON
68
  if y_pred.dim() == 0:
69
  bpred = [float(y_pred.cpu().numpy())]
70
  else:
@@ -74,14 +64,7 @@ def predict_tensor(tensor_file):
74
  except Exception as e:
75
  return {"error": str(e)}
76
 
77
- # --- 4. GRADIO API INTERFACE ---
78
- interface = gr.Interface(
79
- fn=predict_tensor,
80
- inputs=gr.File(label="Input Tensor (.pt)"),
81
- outputs=gr.JSON(label="Prediction Array"),
82
- title="DeepGuard Worker API",
83
- description="Microservice endpoint for EfficientNet-B7 Deepfake Inference."
84
- )
85
 
86
  if __name__ == "__main__":
87
  interface.launch()
 
1
+ WEIGHT_FILE = "final_555_DeepFakeClassifier_tf_efficientnet_b7_ns_0_19"
2
+
3
  import torch
4
  from torch import nn
5
  from torch.nn.modules.dropout import Dropout
 
11
  import gradio as gr
12
  import os
13
 
14
+ # --- 1. MODEL ARCHITECTURE (Matches Notebook Exactly) ---
15
  encoder_params = {
16
  "tf_efficientnet_b7_ns": {
17
  "features": 2560,
 
34
  x = self.fc(x)
35
  return x
36
 
 
 
 
 
37
  print(f"Booting API Worker: Loading {WEIGHT_FILE}...")
38
  device = torch.device('cpu')
39
  model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to(device)
40
 
 
41
  checkpoint = torch.load(WEIGHT_FILE, map_location="cpu", weights_only=False)
42
  state_dict = checkpoint.get("state_dict", checkpoint)
43
  model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
44
  model.eval()
 
45
 
46
+ # --- 3. API ENDPOINT ---
47
  def predict_tensor(tensor_file):
 
 
 
 
48
  if tensor_file is None:
49
  return {"error": "No tensor file received"}
 
50
  try:
51
+ # Load math tensor
52
  x = torch.load(tensor_file.name, map_location=device, weights_only=True)
 
53
  with torch.no_grad():
54
  y_pred = model(x)
55
  y_pred = torch.sigmoid(y_pred.squeeze())
56
 
57
+ # Format to standard list
58
  if y_pred.dim() == 0:
59
  bpred = [float(y_pred.cpu().numpy())]
60
  else:
 
64
  except Exception as e:
65
  return {"error": str(e)}
66
 
67
+ interface = gr.Interface(fn=predict_tensor, inputs=gr.File(label="Input Tensor (.pt)"), outputs=gr.JSON())
 
 
 
 
 
 
 
68
 
69
  if __name__ == "__main__":
70
  interface.launch()