File size: 4,546 Bytes
757208b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import base64
import io
import json
import os

import numpy as np
import torch
import torch.nn as nn
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
from torchvision import transforms
from transformers import DeiTImageProcessor, DeiTForImageClassification

# The Flask application object. CORS is applied to it with flask_cors's
# default settings so a Flutter web client served from another origin can
# call the endpoints registered below.
app = Flask(import_name=__name__)
CORS(app=app)

class WaterClassificationModel:
    """Binary water-image classifier backed by a Hugging Face DeiT checkpoint.

    The image processor and classification model are fetched with
    ``from_pretrained`` at construction time. The model is moved to CUDA when
    available and put in eval mode.
    """

    # Label for each logit index, in index order. Can be overridden per
    # instance via the ``class_names`` constructor argument.
    # NOTE(review): assumes the checkpoint was trained with index 0 =
    # hazardous and index 1 = non_hazardous; confirm against the model's
    # ``config.id2label``.
    class_names = ('hazardous', 'non_hazardous')

    def __init__(self, model_id='durgaprasad143/water-classification-deit', class_names=None):
        """Load the processor and model.

        Args:
            model_id: Hugging Face Hub repo id (or local path) passed to
                ``from_pretrained``.
            class_names: optional sequence of labels, one per logit index.
                When omitted, the class-level default is used.

        Raises:
            Exception: any error from loading is printed and re-raised.
        """
        if class_names is not None:
            self.class_names = tuple(class_names)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"πŸ”„ Loading model from Hub: {model_id}...")

        try:
            self.processor = DeiTImageProcessor.from_pretrained(model_id)
            self.model = DeiTForImageClassification.from_pretrained(model_id)
            self.model.to(self.device)
            self.model.eval()
            print(f"βœ… Model loaded from {model_id}")
        except Exception as e:
            print(f"❌ Failed to load model: {e}")
            raise  # bare raise re-raises the original exception unchanged

    def preprocess_image(self, image_bytes):
        """Decode encoded image bytes into the model's ``pixel_values`` tensor.

        Args:
            image_bytes: raw contents of an image file, in any format PIL can open.

        Returns:
            The processor's ``pixel_values`` tensor, moved to ``self.device``.

        Raises:
            Whatever PIL raises for undecodable data (e.g. ``UnidentifiedImageError``).
        """
        # convert('RGB') returns a new, fully loaded image. The source image
        # (and the file handle it holds) is then closed deterministically
        # instead of being left to the garbage collector.
        with Image.open(io.BytesIO(image_bytes)) as source:
            image = source.convert('RGB')

        # Use the DeiT processor (same as training).
        inputs = self.processor(images=image, return_tensors="pt")
        return inputs['pixel_values'].to(self.device)

    def predict(self, image_bytes):
        """Classify a single encoded image.

        Args:
            image_bytes: raw contents of an image file.

        Returns:
            On success, ``{'prediction': label, 'confidence': float,
            'probabilities': [float, ...]}``, where ``probabilities`` is the
            softmax over the logits in class-index order. On any failure,
            ``{'error': message, 'prediction': 'unknown', 'confidence': 0.0}``.
            Exceptions are printed and never propagated to the caller.
        """
        try:
            input_tensor = self.preprocess_image(image_bytes)

            with torch.no_grad():
                logits = self.model(input_tensor).logits
                probabilities = torch.nn.functional.softmax(logits, dim=1)
                # A single image is processed, so the batch has one row.
                predicted_class = torch.argmax(probabilities, dim=1).item()
                confidence = probabilities[0][predicted_class].item()

            # Fail with a descriptive message (reported via the error dict
            # below) rather than an opaque IndexError when the model emits
            # more classes than there are configured labels.
            if predicted_class >= len(self.class_names):
                raise ValueError(
                    f"Model predicted class index {predicted_class}, but only "
                    f"{len(self.class_names)} class names are configured"
                )

            return {
                'prediction': self.class_names[predicted_class],
                'confidence': confidence,
                'probabilities': probabilities[0].cpu().tolist(),
            }

        except Exception as e:
            print(f"❌ Prediction error: {e}")
            return {
                'error': str(e),
                'prediction': 'unknown',
                'confidence': 0.0,
            }

# Initialize the shared classifier once at import time; every request handler
# uses this single instance. The Hub repo can be overridden with the
# WATER_MODEL_ID environment variable. When it is unset or empty, the class's
# default repo id is used.
_model_id = os.environ.get('WATER_MODEL_ID')
model = WaterClassificationModel(_model_id) if _model_id else WaterClassificationModel()

@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe.

    Always answers with a fixed JSON body; ``model_loaded`` is a hard-coded
    ``True`` and is not re-checked per request.
    """
    status_payload = dict(status='healthy', model_loaded=True)
    return jsonify(status_payload)

@app.route('/predict', methods=['POST'])
def predict():
    """Classify an image uploaded as multipart form data.

    The file must arrive under the form field ``image``. Responds with the
    classifier's result dict as JSON. Returns 400 when the field is missing
    or the upload is empty, and 500 with the exception text on an unexpected
    error.
    """
    # NOTE(review): failures inside model.predict come back as a dict with an
    # 'error' key and are still sent with HTTP 200; clients must inspect the
    # body, not just the status code.
    try:
        # Compare with None explicitly: a FileStorage is falsy when it has no
        # filename, which is not the same as the field being absent.
        upload = request.files.get('image')
        if upload is None:
            return jsonify({'error': 'No image provided'}), 400

        payload = upload.read()
        if len(payload) == 0:
            return jsonify({'error': 'Empty image'}), 400

        return jsonify(model.predict(payload))
    except Exception as exc:
        print(f"❌ API Error: {exc}")
        return jsonify({'error': str(exc)}), 500

@app.route('/predict_base64', methods=['POST'])
def predict_base64():
    """Prediction endpoint for base64 encoded images.

    Expects a JSON object body ``{"image": "<base64>"}``. The value may also be
    a data URL (``data:image/...;base64,<payload>``); everything up to and
    including the first comma is stripped.

    Returns the model's result dict as JSON. Returns ``{'error': ...}`` with
    400 for a missing, malformed or empty image, or 500 for an unexpected
    failure.
    """
    try:
        # silent=True: a missing or non-JSON body yields None (-> 400 below)
        # instead of raising an HTTP error that the generic handler would
        # turn into a 500.
        data = request.get_json(silent=True)

        # The body must be a JSON object; arrays/strings would otherwise pass
        # the membership test and then fail on data['image'] with a 500.
        if not isinstance(data, dict) or 'image' not in data:
            return jsonify({'error': 'No image provided'}), 400

        image_data = data['image']
        if not isinstance(image_data, str):
            return jsonify({'error': 'Image must be a base64 string'}), 400

        # Remove data URL prefix ("data:image/png;base64,").
        _, separator, payload = image_data.partition(',')
        if separator:
            image_data = payload

        # Malformed input raises binascii.Error (a ValueError subclass), and
        # non-ASCII str input raises ValueError. Both are client errors.
        try:
            image_bytes = base64.b64decode(image_data)
        except ValueError:
            return jsonify({'error': 'Invalid base64 image data'}), 400

        # Match /predict: reject an empty payload up front.
        if not image_bytes:
            return jsonify({'error': 'Empty image'}), 400

        result = model.predict(image_bytes)
        return jsonify(result)

    except Exception as e:
        print(f"❌ API Error: {e}")
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    # Debug mode turns on Werkzeug's interactive debugger and, by default, the
    # auto-reloader. Combined with host='0.0.0.0', the debugger would be
    # reachable from the network. The reloader also re-runs this script in a
    # child process, which loads the model a second time. So debug is opt-in:
    # set FLASK_DEBUG to anything other than "", "0", "false" or "no"
    # (the same rule Flask itself uses).
    debug_flag = os.environ.get('FLASK_DEBUG', '')
    debug = bool(debug_flag) and debug_flag.lower() not in ('0', 'false', 'no')
    # Listening port; overridable via PORT, defaults to 5000.
    port = int(os.environ.get('PORT', '5000'))

    print("πŸš€ Starting Water Classification API...")
    print("πŸ“‘ Available endpoints:")
    print("   GET  /health")
    print("   POST /predict (multipart form data)")
    print("   POST /predict_base64 (JSON with base64 image)")
    print(f"🌐 Server running on http://localhost:{port}")

    app.run(host='0.0.0.0', port=port, debug=debug)