Spaces:
Runtime error
Runtime error
| import os | |
| os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache" # Set cache directory to a writable location | |
| from flask import Flask, request, render_template, jsonify | |
| from transformers import ViTForImageClassification, ViTFeatureExtractor | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import io | |
app = Flask(__name__)

# Pretrained ViT backbone, fetched from the Hugging Face Hub on first start
# (cached under the TRANSFORMERS_CACHE directory set at the top of the file).
model_name = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(model_name)
# NOTE(review): feature_extractor is not referenced anywhere in the code shown;
# images are preprocessed by the torchvision `transform` instead. Candidate for removal.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Swap in a classification head sized for the fine-tuned task *before* loading the
# fine-tuned weights, so the state-dict tensor shapes line up.
num_classes = 7
model.classifier = nn.Linear(
    in_features=model.config.hidden_size,
    out_features=num_classes,
)
# NOTE(review): torch.load unpickles the checkpoint file; only load checkpoints you trust.
model.load_state_dict(
    torch.load("skin_cancer_model.pth", map_location="cpu")
)
model.eval()  # evaluation mode for inference
# Human-readable names for the classifier outputs: class_labels[i] names logit i.
# NOTE(review): the order must match the label encoding used when
# skin_cancer_model.pth was trained; confirm against the training script.
class_labels = ['benign_keratosis-like_lesions', 'basal_cell_carcinoma', 'actinic_keratoses', 'vascular_lesions', 'melanocytic_Nevi', 'melanoma', 'dermatofibroma']
# Per-class decision thresholds: the predicted class i is only reported when its
# softmax probability is at least thresholds[i]; otherwise the views answer with a
# "no cancer detected" message. How these values were derived is not shown here.
thresholds = [0.88134295, 0.43095806, 0.39622146, 0.90647435, 0.8128958, 0.05310565, 0.15926854]

# Both lists are indexed by the model's predicted class, so each needs exactly one
# entry per model output. Fail fast at import time instead of raising IndexError
# (or reporting the wrong label) in the middle of a request.
if not (len(class_labels) == len(thresholds) == num_classes):
    raise ValueError(
        f"class_labels ({len(class_labels)}) and thresholds ({len(thresholds)}) "
        f"must each have num_classes ({num_classes}) entries"
    )
# Preprocessing applied to each uploaded image before inference: resize to a fixed
# 224x224, then convert to a float tensor with values scaled to [0, 1].
# NOTE(review): no Normalize step is applied, even though a ViT feature extractor is
# loaded for this model; confirm this matches the preprocessing used in training.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
def index():
    """Render the landing page (index.html) with the application title."""
    # NOTE(review): no @app.route decorator is visible on this view, so as written
    # Flask never registers it; confirm the intended URL rule (presumably '/').
    title = "Skin Cancer Classification Application"
    return render_template('index.html', appName=title)
def model_predict(image):
    """Run the classifier on a single PIL image.

    The image is converted to 3-channel RGB, preprocessed by the module-level
    ``transform``, given a batch dimension, and passed through ``model`` with
    gradient tracking disabled.

    Args:
        image: a ``PIL.Image.Image`` in any mode (RGB, RGBA, grayscale, palette, ...).

    Returns:
        The raw model output for a batch of one; callers read its ``.logits``.
    """
    # Uploaded PNGs are often RGBA or grayscale, and ToTensor keeps the source channel
    # count. Without this conversion the pretrained ViT, which expects 3 input channels,
    # would receive a 4- or 1-channel tensor and raise inside the forward pass.
    rgb_image = image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # (1, 3, H, W): add batch dimension
    with torch.no_grad():
        return model(batch)
# Reported when the most likely class does not clear its per-class threshold.
_BELOW_THRESHOLD_MESSAGE = 'No cancer detected or benign lesion.'
_APP_NAME = "Skin Cancer Classification Application"


def _has_upload(file):
    """Return True when ``file`` exists and carries a non-empty filename.

    Browsers submit an empty part with filename '' when the form is posted with no
    file chosen. Treating that as "missing" avoids a confusing PIL decode error.
    """
    return file is not None and bool(file.filename)


def _classify_upload(file):
    """Classify an uploaded image file (shared by ``api`` and ``predict``).

    Args:
        file: the uploaded file object from ``request.files`` (only ``.read()`` is used).

    Returns:
        The ``class_labels`` entry of the most probable class, or ``None`` when that
        class's softmax probability is below its entry in ``thresholds``.

    Raises:
        Whatever PIL or the model raise for unreadable or unsupported input
        (e.g. ``PIL.UnidentifiedImageError`` for bytes that are not an image).
    """
    image = Image.open(io.BytesIO(file.read()))
    result = model_predict(image)
    # Batch size is 1, so row 0 holds the per-class probabilities.
    probabilities = torch.softmax(result.logits, dim=1).cpu().numpy()[0]
    predicted_idx = int(probabilities.argmax())
    if probabilities[predicted_idx] < thresholds[predicted_idx]:
        return None
    return class_labels[predicted_idx]


def _render_index(**context):
    """Render index.html with the application name plus any extra template variables."""
    return render_template('index.html', appName=_APP_NAME, **context)


def api():
    """JSON endpoint: classify the image uploaded in the ``fileup`` form field.

    Responds with ``{'prediction': label}`` on success, or ``{'Error': ...}`` (plus
    ``'Message'`` for unexpected exceptions) otherwise, always with jsonify's
    default 200 status.
    """
    # NOTE(review): no @app.route decorator is visible for this view, so as written it
    # is never registered with `app`; confirm the intended URL rule and methods.
    try:
        file = request.files.get('fileup')
        if not _has_upload(file):
            return jsonify({'Error': "Please try again. The Image doesn't exist"})
        prediction = _classify_upload(file)
    except Exception as e:
        # Request boundary: keep the traceback in the server log, report the message.
        app.logger.exception('Prediction failed in api()')
        return jsonify({'Error': 'An error occurred', 'Message': str(e)})
    if prediction is None:
        return jsonify({'Error': _BELOW_THRESHOLD_MESSAGE})
    return jsonify({'prediction': prediction})


def predict():
    """HTML endpoint: GET renders the upload form; POST classifies ``fileup``.

    The outcome (label, below-threshold message, or error text) is passed to
    index.html as the ``prediction`` template variable.
    """
    # NOTE(review): no @app.route decorator is visible for this view (it checks
    # request.method, so presumably methods=['GET', 'POST']); confirm and register it.
    if request.method != 'POST':
        # Leave `prediction` undefined (not None), like the plain landing page.
        return _render_index()
    try:
        file = request.files.get('fileup')
        if not _has_upload(file):
            return _render_index(prediction='No file selected.')
        prediction = _classify_upload(file)
    except Exception as e:
        app.logger.exception('Prediction failed in predict()')
        return _render_index(prediction='Error: ' + str(e))
    if prediction is None:
        return _render_index(prediction=_BELOW_THRESHOLD_MESSAGE)
    return _render_index(prediction=prediction)
if __name__ == '__main__':
    # Development server entry point. Debug stays off unless FLASK_DEBUG is set to a
    # truthy value. Debug mode enables the Werkzeug interactive debugger, which allows
    # code execution from the browser if the server is reachable. It also enables the
    # auto-reloader, which re-imports this module in a child process and so loads the
    # model a second time. Host and port default to Flask's own defaults and can be
    # overridden, e.g. HOST=0.0.0.0 when running inside a container.
    app.run(
        host=os.environ.get("HOST", "127.0.0.1"),
        port=int(os.environ.get("PORT", "5000")),
        debug=os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes"),
    )