"""Gradio app that classifies the weather shown in an image with a small CNN."""

import gradio as gr
import wandb  # NOTE(review): not used by this inference script; likely left over from training
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torchvision
import numpy as np
from PIL import Image
import torch.nn.functional as F

# Training-time pipeline, including random augmentation (flip / rotation).
# Kept unchanged for any training code that reuses it; it must NOT be used for
# prediction, because the random ops make results vary between calls.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
])

# Deterministic pipeline for prediction: the same resize and tensor conversion
# as `transform`, without the random augmentations, so a given image always
# produces the same prediction.
inference_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])


# Remember: peak accuracy was 72%.
class WeatherNet(nn.Module):
    """Three conv/ReLU/max-pool stages followed by two fully connected layers.

    Each 3x3 convolution uses stride 1 and padding 1, so it preserves the
    spatial size, and each 2x2 max-pool halves it. A 3-channel 128x128 input
    is therefore 128 x 16 x 16 after the last pool, which matches the input
    width of ``fc1``.

    Args:
        num_classes: Number of output logits (defaults to the 11 weather
            classes this app was trained on).
    """

    def __init__(self, num_classes=11):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(128 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) class logits, one row per input image."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        # Flatten per sample, keeping the batch dimension intact. An input of
        # the wrong spatial size then fails loudly in fc1 instead of being
        # silently regrouped into a different batch size.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


model = WeatherNet()
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only host (the app never moves the model to a GPU).
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load('weather_model.pth', map_location='cpu'))
model.eval()

# Index i corresponds to output logit i of the model.
classes = ['dew', 'fog_smog', 'frost', 'glaze', 'hail', 'lightning', 'rain',
           'rainbow', 'rime', 'sandstorm', 'snow']


def predict_image(image):
    """Classify one image and return the top label and its probability.

    Args:
        image: The image as a NumPy array, presumably as delivered by
            ``gr.Image()`` with its default ``type="numpy"``. It is ``None``
            when the user submits without an image.

    Returns:
        ``(label, probability)``: the predicted class name from ``classes`` and
        its softmax probability as a Python float. If ``image`` is ``None``,
        returns ``("No image provided", "")`` instead of raising.
    """
    if image is None:
        return "No image provided", ""

    pil_image = Image.fromarray(image).convert("RGB")
    # Use the deterministic pipeline: the training `transform` applies random
    # flips/rotations, which would make predictions vary between calls.
    batch = inference_transform(pil_image).unsqueeze(0)  # add batch dim

    with torch.no_grad():
        outputs = model(batch)
        probabilities = F.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)

    index = predicted.item()
    predicted_label = classes[index]
    predicted_probability = probabilities[0][index].item()
    return predicted_label, predicted_probability


# Gradio interface
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.components.Image(),
    outputs=[
        gr.components.Textbox(label="Predicted Label"),
        gr.components.Textbox(label="Prediction Probability"),
    ],
    title="Weather Detection App",
)

if __name__ == "__main__":
    interface.launch()