# File size: 4,546 Bytes
# Revision: 757208b
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import DeiTImageProcessor, DeiTForImageClassification
from PIL import Image
import io
import base64
import json
import numpy as np
# Flask application object; CORS is enabled on it so browser clients
# (e.g. a Flutter web build) can call the API cross-origin.
app = Flask(import_name=__name__)
CORS(app=app)
class WaterClassificationModel:
    """Binary water-image classifier backed by a DeiT model from the HF Hub.

    On construction the image processor and the classification model are
    downloaded, moved to CUDA when available (otherwise CPU) and switched to
    eval mode. ``predict`` never raises: failures are reported in the
    returned dict.
    """

    # Label for each output-logit index. Defined on the class so a bare
    # instance still has it; a per-instance override can be passed to
    # ``__init__``.
    class_names = ('hazardous', 'non_hazardous')

    def __init__(self, model_id='durgaprasad143/water-classification-deit',
                 class_names=None):
        """Load the processor and model for *model_id*.

        Args:
            model_id: Hugging Face Hub repository id to load from.
            class_names: optional sequence of labels indexed by logit
                position; defaults to ``('hazardous', 'non_hazardous')``.

        Raises:
            Exception: whatever loading raises, re-raised after being printed.
        """
        if class_names is not None:
            self.class_names = tuple(class_names)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"π Loading model from Hub: {model_id}...")
        try:
            self.processor = DeiTImageProcessor.from_pretrained(model_id)
            self.model = DeiTForImageClassification.from_pretrained(model_id)
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            print(f"β Failed to load model: {e}")
            # Bare raise keeps the original traceback untouched.
            raise
        print(f"✅ Model loaded from {model_id}")

    def preprocess_image(self, image_bytes):
        """Decode raw image bytes into a pixel tensor on ``self.device``.

        The image is converted to RGB and passed through the DeiT processor
        loaded with the model (the original note says it matches training);
        the ``pixel_values`` tensor it returns is moved to the model device.
        """
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        inputs = self.processor(images=image, return_tensors="pt")
        return inputs['pixel_values'].to(self.device)

    def _format_prediction(self, logits):
        """Turn a logits tensor into the response dict for its first row.

        Assumes *logits* is (batch, num_classes) with batch == 1 — only row 0
        is reported. Raises IndexError if the argmax index has no entry in
        ``self.class_names``.
        """
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        return {
            'prediction': self.class_names[predicted_class],
            'confidence': probabilities[0][predicted_class].item(),
            'probabilities': probabilities[0].cpu().tolist(),
        }

    def predict(self, image_bytes):
        """Classify one encoded image.

        Returns:
            ``{'prediction', 'confidence', 'probabilities'}`` on success, or
            ``{'error', 'prediction': 'unknown', 'confidence': 0.0}`` if any
            step (decoding, inference, label lookup) raises.
        """
        try:
            input_tensor = self.preprocess_image(image_bytes)
            with torch.no_grad():
                logits = self.model(input_tensor).logits
                return self._format_prediction(logits)
        except Exception as e:
            print(f"β Prediction error: {e}")
            return {
                'error': str(e),
                'prediction': 'unknown',
                'confidence': 0.0,
            }
# Build the shared classifier once, at import time, so every request reuses
# it. WaterClassificationModel.__init__ re-raises load failures, so the app
# does not get past this line without a loaded model.
model = WaterClassificationModel()
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe.

    ``model_loaded`` is reported as a constant True: the module-level model
    is built before the server starts and its constructor re-raises on
    failure.
    """
    return jsonify(status='healthy', model_loaded=True)
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an image uploaded as multipart form data (field ``image``).

    Returns the classifier's result dict as JSON, 400 when the field is
    missing or the upload is empty, and 500 on an unexpected error. Note that
    failures caught inside ``model.predict`` come back with status 200 and an
    ``error`` key.
    """
    try:
        upload = request.files.get('image')
        if upload is None:
            return jsonify({'error': 'No image provided'}), 400
        payload = upload.read()
        if payload:
            return jsonify(model.predict(payload))
        return jsonify({'error': 'Empty image'}), 400
    except Exception as exc:
        print(f"β API Error: {exc}")
        return jsonify({'error': str(exc)}), 500
@app.route('/predict_base64', methods=['POST'])
def predict_base64():
    """Classify a base64-encoded image sent as JSON ``{"image": "<base64>"}``.

    A data-URL prefix (e.g. ``data:image/png;base64,``) is stripped before
    decoding.

    Returns:
        The classifier's result dict as JSON; 400 for a missing, non-string,
        undecodable or empty image (or a body that is not a JSON object);
        500 for unexpected errors.
    """
    try:
        # silent=True: malformed JSON / wrong content type yields None (-> 400)
        # instead of an HTTPException the generic handler would turn into 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'image' not in data:
            return jsonify({'error': 'No image provided'}), 400
        image_data = data['image']
        if not isinstance(image_data, str):
            return jsonify({'error': 'Invalid image data'}), 400
        if ',' in image_data:
            # Remove data URL prefix: the payload follows the first comma.
            image_data = image_data.split(',', 1)[1]
        try:
            image_bytes = base64.b64decode(image_data)
        except ValueError:  # binascii.Error (bad padding) subclasses ValueError
            return jsonify({'error': 'Invalid base64 image'}), 400
        if not image_bytes:
            return jsonify({'error': 'Empty image'}), 400
        result = model.predict(image_bytes)
        return jsonify(result)
    except Exception as e:
        print(f"β API Error: {e}")
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Development entry point: print the endpoint list, then start Flask's
    # built-in server on all interfaces, port 5000.
    print("π Starting Water Classification API...")
    print("π‘ Available endpoints:")
    print(" GET /health")
    print(" POST /predict (multipart form data)")
    print(" POST /predict_base64 (JSON with base64 image)")
    print("π Server running on http://localhost:5000")
    # Debug mode is not forced on. Bound to 0.0.0.0 it would expose the
    # Werkzeug interactive debugger (arbitrary code execution, PIN-protected)
    # to the network. Its reloader also re-runs this script in a child
    # process, loading the model a second time. Flask still honours
    # FLASK_DEBUG=1 for local debugging.
    app.run(host='0.0.0.0', port=5000)