import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import gradio as gr
from transformers import ViTFeatureExtractor
from huggingface_hub import hf_hub_download
import spaces
from torchvision import transforms
from transformers import ViTModel
from transformers.modeling_outputs import SequenceClassifierOutput
import torch.nn as nn
import torch.nn.functional as F

HF_TOKEN = os.environ.get("HF_TOKEN")

MODEL_REPO_ID = "limitedonly41/offers_26"
MODEL_FILENAME = "model_50.pt"
VALID_DS_PATH = 'valid_ds.pth'
IMAGE_SIZE = (224, 224)  # spatial size fed to the ViT backbone

# Lazily-populated model cache (filled on the first inference call).
model = None
# Never assigned: the feature-extractor pipeline was replaced by _TRANSFORM below.
# Kept as a module global only for backward compatibility with external readers.
feature_extractor = None
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _torch_load_trusted(path, map_location=None):
    """Unpickle a whole object previously written with ``torch.save``.

    torch>=2.6 defaults ``torch.load`` to ``weights_only=True``, which refuses
    arbitrary pickled objects such as a full ``nn.Module`` or a dataset object,
    so ``weights_only=False`` is passed explicitly.

    SECURITY: this runs pickle deserialization, i.e. it can execute arbitrary
    code. Only call it on files from a trusted source (this Space / our own
    model repo), never on user uploads.
    """
    return torch.load(path, map_location=map_location, weights_only=False)


# Only class_to_idx is used from this object (see _class_name).
valid_ds = _torch_load_trusted(VALID_DS_PATH)


class ViTForImageClassification(nn.Module):
    """ViT backbone with dropout and a linear head on the first output token.

    The downloaded checkpoint is loaded as a fully pickled model (not a
    state_dict). It is presumably an instance of this class, in which case
    unpickling resolves the class by name, so this class must stay defined
    here under this name.
    """

    def __init__(self, num_labels=3):
        super(ViTForImageClassification, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):
        """Return ``(logits, loss)``.

        Args:
            pixel_values: image batch passed straight to ``ViTModel``
                (assumes (batch, 3, 224, 224) - TODO confirm).
            labels: optional class indices. When given, the cross-entropy loss
                is computed and returned as a Python float via ``.item()``;
                otherwise the loss slot is ``None``.
        """
        outputs = self.vit(pixel_values=pixel_values)
        # Classify from position 0 of the last hidden state (the [CLS] slot).
        pooled = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(pooled)
        if labels is None:
            return logits, None
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return logits, loss.item()


# Load an image from file for inference
def load_image(image_path):
    """Open ``image_path`` and return it as an RGB PIL image.

    The context manager closes the underlying file; ``convert`` returns a
    fully loaded copy, so the result stays usable after the file is closed.
    """
    with Image.open(image_path) as img:
        return img.convert("RGB")  # Ensure it's in RGB format


# Preprocessing, built once instead of on every request. No normalization is
# applied - this matches the original inference code (ToTensor scaling only).
_TRANSFORM = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),  # Resize to the model's input size
    transforms.ToTensor(),
])


def _ensure_model_loaded(target_device):
    """Download and unpickle the checkpoint on first use; reuse it afterwards.

    The global is only assigned after a successful load, so a failure leaves
    the cache empty and the next call retries.
    """
    global model
    if model is None:
        model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=MODEL_FILENAME,
            token=HF_TOKEN,
        )
        # map_location lets a checkpoint saved on GPU load on a CPU-only host,
        # replacing the previous bare try/except fallback.
        loaded = _torch_load_trusted(model_path, map_location=target_device)
        loaded.eval()
        model = loaded
    return model


def _class_name(dataset, class_index):
    """Map a predicted class index back to its label via ``dataset.class_to_idx``."""
    idx_to_class = {idx: name for name, idx in dataset.class_to_idx.items()}
    return idx_to_class[class_index]


# Inference function
@spaces.GPU()
def run_inference(image, device, valid_ds):
    """Classify one uploaded image and return the predicted class name.

    Args:
        image: numpy array from ``gr.Image(type="numpy")``, or ``None`` when the
            user submits without an image.
        device: ``torch.device`` to run the model on.
        valid_ds: object exposing ``class_to_idx`` (class name -> index);
            presumably a torchvision ImageFolder - TODO confirm.

    Returns:
        The predicted class name, or a prompt asking for an image when
        ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."

    net = _ensure_model_loaded(device)
    net.to(device)  # no-op when already on the requested device

    # Let Pillow infer the mode from the array shape, then normalize to RGB so
    # grayscale or RGBA inputs are handled instead of misread.
    pil_image = Image.fromarray(np.asarray(image).astype('uint8')).convert("RGB")
    batch = _TRANSFORM(pil_image).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        logits, _ = net(batch, labels=None)

    predicted_class = torch.argmax(logits, dim=1).item()
    return _class_name(valid_ds, predicted_class)


# Create a Gradio interface
iface = gr.Interface(
    fn=lambda image: run_inference(image, device, valid_ds),
    inputs=gr.Image(type="numpy"),
    outputs="text",  # Output is text (predicted class)
    title="Image Classification",
    description="Upload an image to get the predicted class using the ViT model.",
)

# Launch the Gradio app
iface.launch()