File size: 2,306 Bytes
0ee8085
5078ec4
 
 
 
 
 
 
 
 
 
 
1f9d19a
49519f3
5078ec4
f19e145
5078ec4
 
 
 
 
 
 
d959a95
5078ec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f19e145
 
5078ec4
f19e145
514ea67
0ee8085
 
 
d959a95
0ee8085
 
 
 
 
 
 
 
 
 
 
d959a95
f19e145
 
 
 
 
 
0ee8085
f19e145
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch.nn.functional as F


# Inference-time preprocessing: resize to the 128x128 input that WeatherNet's
# first fully connected layer is sized for, then convert the PIL image to a
# tensor.
# Random augmentations (horizontal flip, rotation) are deliberately NOT part
# of this pipeline: it is applied to images at prediction time, and random
# transforms there make the prediction for the same image vary between calls.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# Reminder: peak accuracy reached so far was 72%.
class WeatherNet(nn.Module):
    """Three-stage convolutional classifier for 128x128 RGB images.

    Each stage is a 3x3 convolution (stride 1, padding 1, so height and width
    are preserved) followed by ReLU and a 2x2 max-pool that halves height and
    width: 128 -> 64 -> 32 -> 16. The final 128 x 16 x 16 feature map is
    flattened and fed through two fully connected layers that produce raw
    class logits (no softmax is applied here).

    The layer attribute names (``conv1``..``conv3``, ``fc1``, ``fc2``) are the
    keys of the state dict, so they must stay stable for saved checkpoints to
    keep loading.

    Args:
        num_classes: Number of output logits. Defaults to 11.
    """

    # Height/width the network is sized for; fc1 assumes three halvings of it.
    INPUT_SIZE = 128
    # Channels * height * width of the feature map after the third pool.
    _FLAT_FEATURES = 128 * 16 * 16

    def __init__(self, num_classes: int = 11):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image tensor whose last two dimensions are 128 x 128.

        Returns:
            Tensor of logits with ``num_classes`` columns.

        Raises:
            ValueError: If the spatial size of ``x`` is not 128 x 128.
        """
        # Guard against wrong-sized inputs: the view(-1, ...) below would
        # otherwise silently fold extra spatial data into the batch dimension
        # (e.g. one 256x256 image would come out as a "batch" of four).
        spatial = tuple(x.shape[-2:])
        if spatial != (self.INPUT_SIZE, self.INPUT_SIZE):
            raise ValueError(
                f"WeatherNet expects {self.INPUT_SIZE}x{self.INPUT_SIZE} "
                f"inputs, got spatial size {spatial}"
            )

        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        x = x.view(-1, self._FLAT_FEATURES)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = WeatherNet()

# map_location='cpu' remaps every tensor in the checkpoint onto the CPU, so a
# checkpoint that was saved from a GPU run still loads on a CPU-only host.
# The model is never moved off the CPU in this file, so CPU is the right target.
model.load_state_dict(torch.load('weather_model.pth', map_location='cpu'))
# Switch to evaluation mode; this script only runs inference.
model.eval()

# Class names indexed by the model's output position: entry i is the label
# reported when logit i is the largest.
# NOTE(review): presumably this matches the alphabetical class order used at
# training time -- confirm against the training script.
classes = ['dew', 'fog_smog', 'frost', 'glaze', 'hail', 'lightning', 'rain', 'rainbow', 'rime', 'sandstorm', 'snow']

def predict_image(image):
    """Classify one image and report the winning class and its probability.

    Args:
        image: Pixel array accepted by ``PIL.Image.fromarray`` (the value the
            Gradio image input hands to this callback), or ``None`` when no
            image was supplied.

    Returns:
        A ``(label, probability)`` tuple: the entry of ``classes`` with the
        highest logit, and its softmax probability as a Python float.

    Raises:
        gr.Error: If ``image`` is ``None``, so the UI shows a readable message
            instead of an opaque failure inside ``Image.fromarray``.
    """
    if image is None:
        raise gr.Error("Please provide an image to classify.")

    # Force three channels so grayscale / RGBA inputs match the model's input.
    rgb_image = Image.fromarray(image).convert("RGB")
    # unsqueeze(0) adds the batch dimension: the model sees a batch of one.
    batch = transform(rgb_image).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        probabilities = F.softmax(logits, dim=1)[0]

    best_index = int(logits.argmax(dim=1))
    return classes[best_index], probabilities[best_index].item()

# Gradio UI: one image input wired to predict_image, with two text outputs
# that receive the returned label and probability in that order.
image_input = gr.components.Image()
label_output = gr.components.Textbox(label="Predicted Label")
probability_output = gr.components.Textbox(label="Prediction Probability")

interface = gr.Interface(
    fn=predict_image,
    inputs=image_input,
    outputs=[label_output, probability_output],
    title="Weather Detection App",
)

# Start serving the app.
interface.launch()