yukeshwaradse commited on
Commit
5078ec4
·
verified ·
1 Parent(s): e9fa9ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py CHANGED
@@ -1,4 +1,84 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import wandb
3
  def predict_image(image):
4
  image = Image.fromarray(image).convert("RGB")
 
1
  import gradio as gr
2
+ import wandb
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ import torchvision.transforms as transforms
7
+ from torchvision.datasets import ImageFolder
8
+ from torch.utils.data import DataLoader
9
+ import matplotlib.pyplot as plt
10
+ import torchvision
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+
14
+ transform = transforms.Compose([
15
+ transforms.Resize((128, 128)),
16
+ transforms.ToTensor(),
17
+ #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
18
+ transforms.RandomHorizontalFlip(),
19
+ transforms.RandomRotation(10),
20
+ #transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)
21
+ ])
22
+
23
+ train_dataset = ImageFolder(root='/content/data/train', transform=transform)
24
+ train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
25
+
26
+ val_dataset = ImageFolder(root='/content/data/val', transform=transform)
27
+ val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
28
+
29
+ test_dataset = ImageFolder(root='/content/data/val', transform=transform)
30
+ test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
31
+
32
+ class WeatherNet(nn.Module):
33
+ def __init__(self):
34
+ super(WeatherNet, self).__init__()
35
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
36
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
37
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
38
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
39
+ self.fc1 = nn.Linear(128 * 16 * 16, 512)
40
+ self.fc2 = nn.Linear(512, 11) # 11 classes
41
+
42
+ def forward(self, x):
43
+ x = self.pool(torch.relu(self.conv1(x)))
44
+ x = self.pool(torch.relu(self.conv2(x)))
45
+ x = self.pool(torch.relu(self.conv3(x)))
46
+ x = x.view(-1, 128 * 16 * 16)
47
+ x = torch.relu(self.fc1(x))
48
+ x = self.fc2(x)
49
+ return x
50
+
51
+ model = WeatherNet()
52
+
53
+
54
+ criterion = nn.CrossEntropyLoss()
55
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
56
+
57
+
58
+ num_epochs = 10
59
+ for epoch in range(num_epochs):
60
+ model.train()
61
+ running_loss = 0.0
62
+ for images, labels in train_loader:
63
+ optimizer.zero_grad()
64
+ outputs = model(images)
65
+ loss = criterion(outputs, labels)
66
+ loss.backward()
67
+ optimizer.step()
68
+ running_loss += loss.item()
69
+ print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")
70
+
71
+ model.eval()
72
+ correct = 0
73
+ total = 0
74
+ with torch.no_grad():
75
+ for images, labels in val_loader:
76
+ outputs = model(images)
77
+ _, predicted = torch.max(outputs.data, 1)
78
+ total += labels.size(0)
79
+ correct += (predicted == labels).sum().item()
80
+ print(f'Validation Accuracy: {100 * correct / total:.2f}%')
81
+
82
  import wandb
83
  def predict_image(image):
84
  image = Image.fromarray(image).convert("RGB")