Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import gradio as gr | |
| from transformers import ViTFeatureExtractor | |
| from huggingface_hub import hf_hub_download | |
| import spaces | |
| from torchvision import transforms | |
# Optional Hugging Face access token (None when the env var is unset); it is
# passed to hf_hub_download when fetching the model checkpoint.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Populated lazily by run_inference on the first request.
model = None
# Never assigned anywhere in this file (the ViTFeatureExtractor path is
# commented out); kept only as a module-level name.
feature_extractor = None
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Pickled validation dataset; in this file only its `class_to_idx` mapping is
# used, to translate predicted indices back into class labels.
VALID_DS_PATH = 'valid_ds.pth'
# The file appears to hold a whole pickled dataset object (it exposes
# `class_to_idx`), not a tensor state_dict, so weights_only=False is required on
# PyTorch >= 2.6, where torch.load's default switched to weights_only=True.
# Unpickling can execute arbitrary code: only load this trusted, bundled file.
valid_ds = torch.load(VALID_DS_PATH, weights_only=False)
| from transformers import ViTModel | |
| from transformers.modeling_outputs import SequenceClassifierOutput | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class ViTForImageClassification(nn.Module):
    """ViT backbone plus a dropout + linear head applied to the first token.

    NOTE(review): the checkpoint loaded in run_inference is presumably a pickled
    instance of this class, in which case unpickling resolves it by this exact
    name, so it must not be renamed or moved. TODO confirm.
    """

    def __init__(self, num_labels=3):
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):
        """Run the classifier.

        Args:
            pixel_values: image batch passed straight to the ViT backbone.
            labels: optional target class indices. When given, a cross-entropy
                loss against the logits is computed. Defaults to None so the
                model can be called for plain inference as ``model(x)``.

        Returns:
            A ``(logits, loss)`` tuple. ``logits`` has shape
            ``(batch, num_labels)``. ``loss`` is a Python float, or None when
            ``labels`` is None.
        """
        outputs = self.vit(pixel_values=pixel_values)
        # Classify from the first token of the sequence (the [CLS] token in ViT).
        pooled = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(pooled)
        if labels is None:
            return logits, None
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return logits, loss.item()
# Load an image from file for inference
def load_image(image_path):
    """Open *image_path* and return it as an RGB ``PIL.Image``.

    The source file is opened in a ``with`` block so its handle is closed
    before returning. ``convert`` produces a new, fully loaded image that does
    not depend on the closed file.
    """
    with Image.open(image_path) as img:
        # Ensure it's in RGB format (e.g. grayscale, palette or RGBA files).
        return img.convert("RGB")
# Inference preprocessing: resize to the ViT input size and convert to a float
# tensor. Built once at import time instead of on every request.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to the model's input size
    transforms.ToTensor(),
])


def _ensure_model_loaded(device):
    """Download and unpickle the classifier on first use; return the cached one afterwards."""
    global model
    if model is not None:
        return model
    model_path = hf_hub_download(
        repo_id="limitedonly41/offers_26",
        filename="model_50.pt",
        token=HF_TOKEN,  # `use_auth_token` is deprecated in huggingface_hub
    )
    # The checkpoint is unpickled as a whole model object (it is called and
    # .eval()'d directly), so weights_only=False is required on PyTorch >= 2.6.
    # Only do this for a trusted repository: unpickling can run arbitrary code.
    # map_location remaps tensors saved on another device (e.g. CUDA) onto
    # `device`, replacing the old bare `except:` retry on CPU-only hosts.
    loaded = torch.load(model_path, map_location=device, weights_only=False)
    loaded.to(device)
    loaded.eval()
    # Publish the global only once the model is fully ready.
    model = loaded
    return model


def _idx_to_class(valid_ds, class_idx):
    """Map a predicted class index back to its label via ``valid_ds.class_to_idx``."""
    idx_to_class = {idx: name for name, idx in valid_ds.class_to_idx.items()}
    return idx_to_class[class_idx]


def run_inference(image, device, valid_ds):
    """Classify one image and return its class label.

    The model is downloaded and loaded only on the first call. (Previously a
    never-assigned ``feature_extractor`` check forced a reload on every call.)

    Args:
        image: array from the Gradio ``gr.Image(type="numpy")`` input,
            convertible to uint8, or None when nothing was uploaded.
        device: torch device that holds the model and the input tensor.
        valid_ds: object exposing ``class_to_idx`` (label -> index), used to
            translate the predicted index back into a label.

    Returns:
        The predicted class label (a key of ``valid_ds.class_to_idx``), or a
        prompt string when no image was provided.
    """
    if image is None:
        return "Please upload an image."
    net = _ensure_model_loaded(device)
    # Let fromarray infer the mode from the array shape (passing a mode is
    # deprecated in recent Pillow and mis-reads non-RGB arrays). Then normalise
    # any grayscale/RGBA input to the 3 channels the model expects.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    input_tensor = _PREPROCESS(pil_image).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        logits, _ = net(input_tensor, labels=None)
    predicted_class = torch.argmax(logits, dim=1).item()
    return _idx_to_class(valid_ds, predicted_class)
def _classify(image):
    """Gradio callback: classify *image* using the module-level device and dataset."""
    return run_inference(image, device, valid_ds)


# Gradio UI: a single numpy image upload in, the predicted class name out as text.
image_input = gr.Image(type="numpy")
iface = gr.Interface(
    fn=_classify,
    inputs=image_input,
    outputs="text",
    title="Image Classification",
    description="Upload an image to get the predicted class using the ViT model.",
)

# Start the Gradio app.
iface.launch()