matikosowy commited on
Commit
e189f0a
·
verified ·
1 Parent(s): 772f03b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -5,6 +5,7 @@ from torchvision import transforms
5
  import torchvision.models as models
6
  import torch.nn as nn
7
 
 
8
  class DummyModel(nn.Module):
9
  def __init__(self):
10
  super(DummyModel, self).__init__()
@@ -80,21 +81,32 @@ class DummyModel(nn.Module):
80
 
81
  return dec4
82
 
 
83
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
84
  model = DummyModel()
85
- model.load_state_dict(torch.load('model.pth'))
86
  model = model.to(device)
87
  model.eval()
88
 
 
89
  # Define preprocessing transforms
90
- preprocess = transforms.Compose([
91
- transforms.Resize(150),
92
  transforms.ToTensor(),
93
- transforms.Normalize([0.5), [0.5])
94
  ])
95
 
 
 
 
 
 
 
 
 
 
96
  def predict(image):
97
- image = preprocess(image).to(model.device)
98
  with torch.no_grad():
99
  output = model(image)
100
 
@@ -102,7 +114,7 @@ def predict(image):
102
 
103
  return image
104
 
105
- # Create Gradio interface
106
  iface = gr.Interface(fn=predict,
107
  inputs=gr.Image(type="pil"),
108
  outputs=gr.Image(type="pil"))
 
5
  import torchvision.models as models
6
  import torch.nn as nn
7
 
8
+
9
  class DummyModel(nn.Module):
10
  def __init__(self):
11
  super(DummyModel, self).__init__()
 
81
 
82
  return dec4
83
 
84
+
85
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
86
  model = DummyModel()
87
+ model.load_state_dict(torch.load('model.pth', weights_only=True))
88
  model = model.to(device)
89
  model.eval()
90
 
91
+
92
  # Define preprocessing transforms
93
+ transform = transforms.Compose([
94
+ transforms.Resize((150, 150)),
95
  transforms.ToTensor(),
96
+ transforms.Normalize([0.5], [0.5])
97
  ])
98
 
99
+
100
+ def preprocess(image):
101
+ image = image.convert('L')
102
+ image = transform(image)
103
+ image = image.unsqueeze(0)
104
+
105
+ return image
106
+
107
+
108
  def predict(image):
109
+ image = preprocess(image).to(device)
110
  with torch.no_grad():
111
  output = model(image)
112
 
 
114
 
115
  return image
116
 
117
+
118
  iface = gr.Interface(fn=predict,
119
  inputs=gr.Image(type="pil"),
120
  outputs=gr.Image(type="pil"))