|
|
| from flask import Flask, request, jsonify |
| from transformers import AutoModelForImageClassification |
| from PIL import Image |
| import torch |
| import torchvision.transforms as transforms |
|
|
# The WSGI application object; routes below are registered on it.
app = Flask(import_name=__name__)
|
|
| |
# Hugging Face Hub identifier of the classifier served by /predict.
model_name = "SanketJadhav/PlantDiseaseClassifier-Resnet50"


def _load_model(name):
    """Load the image-classification model *name*, or return None on failure.

    Failure is deliberately non-fatal: the module still imports, ``model``
    is left as None, and ``/predict`` answers 500 "Model not loaded".
    """
    try:
        loaded = AutoModelForImageClassification.from_pretrained(
            name,
            # False: load the PyTorch (.bin) checkpoint instead of safetensors.
            # NOTE(review): .bin weights are unpickled on load — only point
            # this at trusted repositories.
            use_safetensors=False,
        )
        print(f"✅ Model '{name}' loaded successfully")
    except Exception as err:
        print(f"❌ Error loading model: {err}")
        return None
    return loaded


model = _load_model(model_name)
|
|
| |
# Per-channel (R, G, B) normalisation statistics — the standard ImageNet
# values. Presumably these match the model's training-time preprocessing;
# TODO confirm against the model card.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image before inference:
# resize to a fixed 224x224 (aspect ratio is not preserved), convert the
# PIL image to a float tensor, then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
def _read_rgb_image(file_storage):
    """Decode an uploaded file into an RGB PIL image.

    ``Image.open`` is lazy — it only identifies the format — so the
    ``convert`` call is what forces the full decode. Keeping both here means
    every decode failure (unrecognised format, truncated data, decompression
    bomb) is raised from this one helper.
    """
    return Image.open(file_storage.stream).convert("RGB")


def _classify(image):
    """Run the module-level model on *image* and return its id2label entry."""
    # Add a leading batch dimension of size 1 to the preprocessed tensor.
    pixel_values = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
    predicted_class = logits.argmax(-1).item()
    return model.config.id2label[predicted_class]


@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image and return the predicted label as JSON.

    Expects the image file in the ``image`` field of ``request.files``.

    Returns:
        200 ``{"prediction": <label>}`` on success.
        400 ``{"error": ...}`` when no ``image`` file is sent, or it cannot
        be decoded as an image.
        500 ``{"error": ...}`` when the model failed to load at startup or
        inference fails unexpectedly (the traceback is logged server-side).
    """
    if model is None:
        return jsonify({"error": "Model not loaded"}), 500

    if "image" not in request.files:
        return jsonify({"error": "No image file provided"}), 400

    try:
        image = _read_rgb_image(request.files["image"])
    except (OSError, Image.DecompressionBombError):
        # Bad uploads (empty, unrecognised format, truncated, oversized)
        # are client errors, not server failures. UnidentifiedImageError
        # is a subclass of OSError.
        return jsonify({"error": "Invalid or unsupported image file"}), 400

    try:
        label = _classify(image)
    except Exception:
        # Request boundary: keep the full traceback in the server log
        # instead of echoing internal exception text to the client.
        app.logger.exception("Prediction failed")
        return jsonify({"error": "Prediction failed"}), 500

    return jsonify({"prediction": label})
|
|
if __name__ == "__main__":
    # Development server entry point: listen on every interface, port 7860.
    app.run(port=7860, host="0.0.0.0")
|
|
|
|