Rujit commited on
Commit
0fe1a7a
·
verified ·
1 Parent(s): aeea6fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -4,6 +4,7 @@ from torchvision import models, transforms
4
  from PIL import Image
5
  import torch.nn.functional as F
6
  import torch.nn as nn
 
7
 
8
  # Class labels
9
  class_names = ['fake', 'real']
@@ -35,19 +36,26 @@ model, device = load_model()
35
 
36
  # Inference function
37
  def predict(image):
 
 
 
 
 
38
  if image.mode == "RGBA":
39
  image = image.convert("RGB")
40
-
 
41
  image = data_transforms(image).unsqueeze(0).to(device)
 
42
  with torch.no_grad():
43
  outputs = model(image)
44
  probs = F.softmax(outputs, dim=1)
45
  conf, pred = torch.max(probs, 1)
46
-
47
- label = class_names[pred.item()]
48
- confidence = f"{conf.item() * 100:.2f}%"
49
  return f"{label} ({confidence})"
50
 
51
  # Gradio interface
52
  demo = gr.Interface(fn=predict, inputs="image", outputs="text", api_name="predict")
53
- demo.launch()
 
4
  from PIL import Image
5
  import torch.nn.functional as F
6
  import torch.nn as nn
7
+ import numpy as np
8
 
9
  # Class labels
10
  class_names = ['fake', 'real']
 
36
 
37
  # Inference function
38
  def predict(image):
39
+ # Convert numpy array to PIL Image if needed
40
+ if isinstance(image, np.ndarray):
41
+ image = Image.fromarray(image)
42
+
43
+ # Convert RGBA to RGB if needed
44
  if image.mode == "RGBA":
45
  image = image.convert("RGB")
46
+
47
+ # Apply transforms
48
  image = data_transforms(image).unsqueeze(0).to(device)
49
+
50
  with torch.no_grad():
51
  outputs = model(image)
52
  probs = F.softmax(outputs, dim=1)
53
  conf, pred = torch.max(probs, 1)
54
+ label = class_names[pred.item()]
55
+ confidence = f"{conf.item() * 100:.2f}%"
56
+
57
  return f"{label} ({confidence})"
58
 
59
  # Gradio interface
60
  demo = gr.Interface(fn=predict, inputs="image", outputs="text", api_name="predict")
61
+ demo.launch()