yukeshwaradse commited on
Commit
d959a95
·
verified ·
1 Parent(s): f19e145

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -113
app.py CHANGED
@@ -13,112 +13,7 @@ import numpy as np
13
  from PIL import Image
14
  import torch.nn.functional as F
15
 
16
- # transform = transforms.Compose([
17
- # transforms.Resize((128, 128)),
18
- # transforms.ToTensor(),
19
- # #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
20
- # transforms.RandomHorizontalFlip(),
21
- # transforms.RandomRotation(10),
22
- # #transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)
23
- # ])
24
 
25
- # train_dataset = ImageFolder(root='data/train', transform=transform)
26
- # train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
27
-
28
- # val_dataset = ImageFolder(root='data/val', transform=transform)
29
- # val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
30
-
31
- # test_dataset = ImageFolder(root='data/val', transform=transform)
32
- # test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
33
-
34
- # class WeatherNet(nn.Module):
35
- # def __init__(self):
36
- # super(WeatherNet, self).__init__()
37
- # self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
38
- # self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
39
- # self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
40
- # self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
41
- # self.fc1 = nn.Linear(128 * 16 * 16, 512)
42
- # self.fc2 = nn.Linear(512, 11) # 11 classes
43
-
44
- # def forward(self, x):
45
- # x = self.pool(torch.relu(self.conv1(x)))
46
- # x = self.pool(torch.relu(self.conv2(x)))
47
- # x = self.pool(torch.relu(self.conv3(x)))
48
- # x = x.view(-1, 128 * 16 * 16)
49
- # x = torch.relu(self.fc1(x))
50
- # x = self.fc2(x)
51
- # return x
52
-
53
- # model = WeatherNet()
54
-
55
-
56
- # criterion = nn.CrossEntropyLoss()
57
- # optimizer = optim.Adam(model.parameters(), lr=0.001)
58
-
59
-
60
- # num_epochs = 10
61
- # for epoch in range(num_epochs):
62
- # model.train()
63
- # running_loss = 0.0
64
- # for images, labels in train_loader:
65
- # optimizer.zero_grad()
66
- # outputs = model(images)
67
- # loss = criterion(outputs, labels)
68
- # loss.backward()
69
- # optimizer.step()
70
- # running_loss += loss.item()
71
- # print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")
72
-
73
- # model.eval()
74
- # correct = 0
75
- # total = 0
76
- # with torch.no_grad():
77
- # for images, labels in val_loader:
78
- # outputs = model(images)
79
- # _, predicted = torch.max(outputs.data, 1)
80
- # total += labels.size(0)
81
- # correct += (predicted == labels).sum().item()
82
- # print(f'Validation Accuracy: {100 * correct / total:.2f}%')
83
- # torch.save(model.state_dict(), 'weather_model.pth')
84
-
85
-
86
- # def predict_image(image):
87
- # image = Image.fromarray(image).convert("RGB")
88
- # image = transform(image)
89
- # image = image.unsqueeze(0) # Add batch dimension
90
-
91
- # with torch.no_grad():
92
- # outputs = model(image)
93
- # probabilities = F.softmax(outputs, dim=1)
94
- # _, predicted = torch.max(outputs, 1)
95
-
96
- # predicted_label = classes[predicted.item()]
97
- # predicted_probability = probabilities[0][predicted.item()].item()
98
-
99
- # return predicted_label, predicted_probability
100
-
101
- # # Class labels
102
- # classes = ['dew', 'fog_smog', 'frost', 'glaze', 'hail', 'lightning', 'rain', 'rainbow', 'rime', 'sandstorm', 'snow']
103
-
104
- # # Create Gradio interface
105
- # interface = gr.Interface(fn=predict_image,
106
- # inputs=gr.components.Image(),
107
- # outputs=[gr.components.Textbox(label="Predicted Label"), gr.components.Textbox(label="Prediction Probability")],title="Weather Detection App")
108
-
109
- # # Launch the interface
110
- # interface.launch()
111
-
112
- import gradio as gr
113
- import torch
114
- import torch.nn as nn
115
- import torchvision.transforms as transforms
116
- from torchvision.datasets import ImageFolder
117
- from torch.utils.data import DataLoader
118
- import torch.nn.functional as F
119
- from PIL import Image
120
-
121
- # Define the transformation
122
  transform = transforms.Compose([
123
  transforms.Resize((128, 128)),
124
  transforms.ToTensor(),
@@ -126,7 +21,7 @@ transform = transforms.Compose([
126
  transforms.RandomRotation(10),
127
  ])
128
 
129
- # Define the model
130
  class WeatherNet(nn.Module):
131
  def __init__(self):
132
  super(WeatherNet, self).__init__()
@@ -146,21 +41,17 @@ class WeatherNet(nn.Module):
146
  x = self.fc2(x)
147
  return x
148
 
149
- # Initialize the model
150
  model = WeatherNet()
151
 
152
- # Load the saved model weights
153
  model.load_state_dict(torch.load('weather_model.pth'))
154
  model.eval()
155
 
156
- # Class labels
157
  classes = ['dew', 'fog_smog', 'frost', 'glaze', 'hail', 'lightning', 'rain', 'rainbow', 'rime', 'sandstorm', 'snow']
158
 
159
- # Prediction function
160
  def predict_image(image):
161
  image = Image.fromarray(image).convert("RGB")
162
  image = transform(image)
163
- image = image.unsqueeze(0) # Add batch dimension
164
 
165
  with torch.no_grad():
166
  outputs = model(image)
@@ -172,7 +63,7 @@ def predict_image(image):
172
 
173
  return predicted_label, predicted_probability
174
 
175
- # Create Gradio interface
176
  interface = gr.Interface(
177
  fn=predict_image,
178
  inputs=gr.components.Image(),
@@ -180,5 +71,4 @@ interface = gr.Interface(
180
  title="Weather Detection App"
181
  )
182
 
183
- # Launch the interface
184
  interface.launch()
 
13
  from PIL import Image
14
  import torch.nn.functional as F
15
 
 
 
 
 
 
 
 
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  transform = transforms.Compose([
18
  transforms.Resize((128, 128)),
19
  transforms.ToTensor(),
 
21
  transforms.RandomRotation(10),
22
  ])
23
 
24
+ #Remember peak acc was 72%
25
  class WeatherNet(nn.Module):
26
  def __init__(self):
27
  super(WeatherNet, self).__init__()
 
41
  x = self.fc2(x)
42
  return x
43
 
 
44
  model = WeatherNet()
45
 
 
46
  model.load_state_dict(torch.load('weather_model.pth'))
47
  model.eval()
48
 
 
49
  classes = ['dew', 'fog_smog', 'frost', 'glaze', 'hail', 'lightning', 'rain', 'rainbow', 'rime', 'sandstorm', 'snow']
50
 
 
51
  def predict_image(image):
52
  image = Image.fromarray(image).convert("RGB")
53
  image = transform(image)
54
+ image = image.unsqueeze(0)
55
 
56
  with torch.no_grad():
57
  outputs = model(image)
 
63
 
64
  return predicted_label, predicted_probability
65
 
66
+ #Gradio interface
67
  interface = gr.Interface(
68
  fn=predict_image,
69
  inputs=gr.components.Image(),
 
71
  title="Weather Detection App"
72
  )
73
 
 
74
  interface.launch()