# WeatherRacoon / app.py
# Last commit by yukeshwaradse: "Update app.py" (d959a95, verified)
import gradio as gr
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch.nn.functional as F
# Inference-time preprocessing: resize to the 128x128 input that WeatherNet's
# fc1 layer is sized for (three 2x2 poolings -> 16x16), then convert the PIL
# image to a float tensor.
# Random augmentations (horizontal flip / rotation) are deliberately NOT
# applied here: they are training-time tricks, and at inference they make the
# same uploaded image produce different predictions on every request.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])
# NOTE: the peak accuracy reached with this model was 72%.
class WeatherNet(nn.Module):
    """Small three-block CNN classifier for 128x128 RGB images.

    Each block is a 3x3 same-padded convolution -> ReLU -> 2x2 max-pool, so a
    3x128x128 input becomes 128x16x16 before the two fully connected layers.

    Args:
        num_classes: Number of output logits. Defaults to 11, the size of the
            checkpoint this app loads and of its ``classes`` label list.
    """

    # Flattened feature size after the conv stack for a 128x128 input:
    # 128 channels * (128 / 2**3) * (128 / 2**3).
    _FLAT_FEATURES = 128 * 16 * 16

    def __init__(self, num_classes: int = 11):
        super().__init__()
        # Attribute names must stay as-is: they are the state_dict keys of the
        # saved checkpoint.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a (batch, 3, 128, 128) tensor to (batch, num_classes) logits."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        # Flatten per sample while keeping the batch axis. A view(-1, N) here
        # would silently reinterpret extra spatial data as extra samples for
        # non-128x128 inputs; flatten makes fc1 raise a shape error instead.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
# Build the network and load the trained weights. map_location="cpu" lets a
# checkpoint that was saved from GPU tensors load on a CPU-only host; it is
# required here anyway because predict_image never moves inputs off the CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model = WeatherNet()
model.load_state_dict(torch.load('weather_model.pth', map_location='cpu'))
model.eval()  # switch to inference mode

# Human-readable label for each output logit, indexed by class id.
# Presumably the alphabetical folder order ImageFolder assigned during
# training -- keep in sync with the training data.
classes = ['dew', 'fog_smog', 'frost', 'glaze', 'hail', 'lightning', 'rain', 'rainbow', 'rime', 'sandstorm', 'snow']
# Fail fast at startup rather than mislabeling (or IndexError-ing) at request
# time if the label list and the model's output size ever drift apart.
if len(classes) != model.fc2.out_features:
    raise ValueError(
        f"classes has {len(classes)} labels but the model outputs "
        f"{model.fc2.out_features} logits"
    )
def predict_image(image):
    """Classify one image and return its top-1 weather label and probability.

    Args:
        image: The uploaded image, either as a numpy array (what ``gr.Image``
            passes by default) or as an already-decoded ``PIL.Image.Image``.

    Returns:
        A ``(label, probability)`` tuple: the predicted name from ``classes``
        and its softmax probability as a Python float.

    Raises:
        gr.Error: If no image was provided (Gradio passes ``None`` when the
            user submits without uploading anything).
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    # convert("RGB") drops an alpha channel / expands grayscale so the input
    # always has the 3 channels conv1 expects; unsqueeze adds the batch axis.
    batch = transform(pil_image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        probabilities = F.softmax(model(batch), dim=1)
    # Softmax is monotonic, so the most probable class is the top logit; one
    # max() yields both the probability and its class index.
    top_probability, top_index = probabilities.max(dim=1)
    return classes[top_index.item()], top_probability.item()
#Gradio interface
# Web UI: one image input fed to predict_image; its (label, probability)
# return value is shown, in order, in the two text boxes below.
image_input = gr.Image()
prediction_outputs = [
    gr.Textbox(label="Predicted Label"),
    gr.Textbox(label="Prediction Probability"),
]
interface = gr.Interface(
    fn=predict_image,
    inputs=image_input,
    outputs=prediction_outputs,
    title="Weather Detection App",
)
interface.launch()