Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| import torchvision.models as models | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| # Define the VisionTransformer model class | |
class VisionTransformer(nn.Module):
    """Wrap torchvision's ViT-B/16 with a classification head of custom width.

    The backbone is created without pretrained weights; its final linear
    layer is replaced so the network emits ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        """Build the backbone and swap in a ``num_classes``-wide head."""
        super().__init__()
        backbone = models.vit_b_16(weights=None)  # weights are loaded later
        head_inputs = backbone.heads.head.in_features
        backbone.heads.head = nn.Linear(head_inputs, num_classes)
        self.model = backbone

    def forward(self, X):
        """Return the wrapped model's output for the input batch ``X``."""
        return self.model(X)
| # Function to load the model | |
def load_model(model_path, num_classes):
    """Create a ``VisionTransformer`` and restore its weights from disk.

    Args:
        model_path: Path to a checkpoint holding the model's state dict.
        num_classes: Number of output classes of the classification head.

    Returns:
        The model with the checkpoint loaded, switched to evaluation mode.
    """
    net = VisionTransformer(num_classes=num_classes)
    # Map every tensor to the CPU so the checkpoint loads without a GPU.
    cpu = torch.device('cpu')
    state_dict = torch.load(model_path, map_location=cpu)
    net.load_state_dict(state_dict)
    net.eval()
    return net
| # Preprocess the input image | |
def preprocess_image(image_path):
    """Load an image file and turn it into a normalized, batched tensor.

    The image is converted to RGB, resized (``Resize(224)``), center-cropped
    to 224x224, converted to a tensor and normalized per channel with the
    mean/std values commonly used for ImageNet-trained torchvision models.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The transformed image with a leading batch dimension added.
    """
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Open via a context manager so the file handle is closed deterministically;
    # convert() returns an independent image that outlives the ``with`` block.
    with Image.open(image_path) as raw:
        image = raw.convert("RGB")
    return transform(image).unsqueeze(0)  # Add batch dimension
| # Dictionary mapping short names to full cloud names | |
# Lookup table: class abbreviation -> full cloud name shown to the user.
cloud_name_mapping = dict(
    AC='Altocumulus',
    As='Altostratus',
    Cb='Cumulonimbus',
    Cc='Cirrocumulus',
    Ci='Cirrus',
    Cs='Cirrostratus',
    Ct='Contrails',
    Cu='Cumulus',
    Ns='Nimbostratus',
    Sc='Stratocumulus',
    St='Stratus',
)
| # Function to make a prediction | |
def predict(image_path, model, class_names):
    """Classify one image and report the winning class with its probability.

    Args:
        image_path: Path of the image file to classify.
        model: Callable model that maps the preprocessed batch to logits.
        class_names: Class abbreviations, indexed by the model's output index.

    Returns:
        A ``(full_name, confidence)`` tuple: the full cloud name looked up in
        ``cloud_name_mapping`` ("Unknown" if the abbreviation is missing) and
        the softmax probability of the predicted class as a float.
    """
    batch = preprocess_image(image_path)
    with torch.no_grad():  # inference only, no gradients needed
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
        top_index = torch.max(logits, 1)[1].item()

    confidence = probs[0][top_index].item()
    short_name = class_names[top_index]
    return cloud_name_mapping.get(short_name, "Unknown"), confidence
| # Load the model | |
# Checkpoint location; replace with your model path.
model_path = "VisionTransformer_with_crop_final_model.pth"

# Class abbreviations; replace with your actual class names. They are indexed
# by the model's predicted class index, so the order matters.
class_names = [
    'AC', 'As', 'Cb', 'Cc', 'Ci', 'Cs',
    'Ct', 'Cu', 'Ns', 'Sc', 'St',
]

# Load the network once at start-up, sized to the number of classes.
model = load_model(model_path, num_classes=len(class_names))
| # Gradio interface | |
def classify_image(image_path):
    """Gradio callback: classify the uploaded image and format the result.

    Args:
        image_path: Filesystem path of the uploaded image, or ``None`` when
            the callback is triggered without an image.

    Returns:
        A two-line string with the predicted cloud name and the confidence
        formatted to two decimals, or a short prompt when no image was given.
    """
    # With type="filepath" Gradio passes None if nothing was uploaded; guard
    # here instead of letting the image loader fail on a missing path.
    if not image_path:
        return "Please upload an image."
    predicted_class, confidence = predict(image_path, model, class_names)
    return f"Prediction: {predicted_class}\nConfidence: {confidence:.2f}"
# Build the web UI: one image input (passed to the callback as a file path)
# and a plain-text output, then start serving it.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),
    outputs="text",
    title="Vision Transformer Image Classification",
)
demo.launch()