Spaces:
Runtime error
Runtime error
import io
import os

# Kept ahead of the `transformers` import so the cache location is already set
# when that library loads.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"  # Set cache directory to a writable location

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import ViTForImageClassification, ViTFeatureExtractor
app = FastAPI()

# Load the ViT backbone and its feature extractor. The feature extractor is used
# below as the source of the input-normalisation statistics for `transform`.
# NOTE(review): ViTFeatureExtractor is deprecated in recent `transformers`
# releases in favour of ViTImageProcessor (same image_mean/image_std attributes).
model_name = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(model_name)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Swap the classification head for one sized to the 7 lesion classes, then load
# the fine-tuned weights over the whole model.
num_classes = 7
model.classifier = nn.Linear(model.config.hidden_size, num_classes)
# NOTE(review): this path mirrors a Hugging Face Hub repo id
# ("Anwarkh1/Skin_Cancer-Image_Classification"). torch.load expects a single
# checkpoint file holding a state_dict; if this is a saved model *directory*,
# load it with `ViTForImageClassification.from_pretrained(<path>)` instead -
# confirm, as this is a likely cause of the startup failure.
# torch.load unpickles its input, so it must only ever point at trusted files.
model.load_state_dict(torch.load("models/Anwarkh1/Skin_Cancer-Image_Classification", map_location=torch.device('cpu')))
model.eval()

# Class labels, presumably in the same index order as the trained head's
# outputs - verify against the training label mapping.
class_labels = ['benign_keratosis-like_lesions', 'basal_cell_carcinoma', 'actinic_keratoses', 'vascular_lesions', 'melanocytic_Nevi', 'melanoma', 'dermatofibroma']

# Preprocessing: resize to the 224x224 input of the patch16-224 ViT, convert to a
# float tensor in [0, 1], then normalise with the feature extractor's mean/std.
# The previous pipeline skipped normalisation, so inputs did not match what the
# pretrained backbone expects (presumably also what fine-tuning used - confirm).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
])
# Define API endpoint for model inference.
# NOTE(review): the original had no route decorator, so this handler was never
# registered with the app; the "/predict/" path is assumed - match it to clients.
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded skin-lesion image.

    Args:
        file: Uploaded image in any format Pillow can decode.

    Returns:
        A dict with:
          - 'predicted_class': the label from `class_labels` with the highest score.
          - 'accuracy': despite its name, the softmax probability of the
            predicted class (model confidence), not an accuracy metric. The key
            name is kept so existing clients keep working.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        # convert("RGB") forces decoding and normalises RGBA/grayscale/palette
        # images to the 3 channels the model's patch embedding expects.
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # OSError covers PIL.UnidentifiedImageError and truncated files.
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image.") from err

    batch = transform(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        outputs = model(batch)

    # Softmax over the class dimension, then pick the top class of the single
    # sample in the batch.
    probabilities = torch.softmax(outputs.logits, dim=1)
    top_probs, top_indices = torch.max(probabilities, dim=1)
    predicted_idx = top_indices[0].item()
    predicted_label = class_labels[predicted_idx]
    predicted_confidence = top_probs[0].item()
    return {'predicted_class': predicted_label, 'accuracy': predicted_confidence}