durgaprasad143 commited on
Commit
757208b
Β·
verified Β·
1 Parent(s): b496671

Upload flask_api.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. flask_api.py +135 -0
flask_api.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from flask_cors import CORS
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision import transforms
6
+ from transformers import DeiTImageProcessor, DeiTForImageClassification
7
+ from PIL import Image
8
+ import io
9
+ import base64
10
+ import json
11
+ import numpy as np
12
+
13
+ app = Flask(__name__)
14
+ CORS(app) # Enable CORS for Flutter web support
15
+
16
+ class WaterClassificationModel:
17
+ def __init__(self, model_id='durgaprasad143/water-classification-deit'):
18
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+ print(f"πŸ”„ Loading model from Hub: {model_id}...")
20
+
21
+ try:
22
+ self.processor = DeiTImageProcessor.from_pretrained(model_id)
23
+ self.model = DeiTForImageClassification.from_pretrained(model_id)
24
+ self.model.to(self.device)
25
+ self.model.eval()
26
+ print(f"βœ… Model loaded from {model_id}")
27
+ except Exception as e:
28
+ print(f"❌ Failed to load model: {e}")
29
+ raise e
30
+
31
+ def preprocess_image(self, image_bytes):
32
+ """Preprocess image for model input using DeiT processor"""
33
+ # Open image from bytes
34
+ image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
35
+
36
+ # Use the DeiT processor (same as training)
37
+ inputs = self.processor(images=image, return_tensors="pt")
38
+ return inputs['pixel_values'].to(self.device)
39
+
40
+ def predict(self, image_bytes):
41
+ """Make prediction on image"""
42
+ try:
43
+ # Preprocess image
44
+ input_tensor = self.preprocess_image(image_bytes)
45
+
46
+ # Make prediction
47
+ with torch.no_grad():
48
+ outputs = self.model(input_tensor).logits
49
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
50
+ predicted_class = torch.argmax(probabilities, dim=1).item()
51
+ confidence = probabilities[0][predicted_class].item()
52
+
53
+ # Map to labels
54
+ class_names = ['hazardous', 'non_hazardous']
55
+ prediction = class_names[predicted_class]
56
+
57
+ return {
58
+ 'prediction': prediction,
59
+ 'confidence': confidence,
60
+ 'probabilities': probabilities[0].cpu().numpy().tolist()
61
+ }
62
+
63
+ except Exception as e:
64
+ print(f"❌ Prediction error: {e}")
65
+ return {
66
+ 'error': str(e),
67
+ 'prediction': 'unknown',
68
+ 'confidence': 0.0
69
+ }
70
+
71
+ # Initialize model
72
+ model = WaterClassificationModel()
73
+
74
+ @app.route('/health', methods=['GET'])
75
+ def health_check():
76
+ """Health check endpoint"""
77
+ return jsonify({'status': 'healthy', 'model_loaded': True})
78
+
79
+ @app.route('/predict', methods=['POST'])
80
+ def predict():
81
+ """Prediction endpoint"""
82
+ try:
83
+ # Get image from request
84
+ if 'image' not in request.files:
85
+ return jsonify({'error': 'No image provided'}), 400
86
+
87
+ image_file = request.files['image']
88
+ image_bytes = image_file.read()
89
+
90
+ if not image_bytes:
91
+ return jsonify({'error': 'Empty image'}), 400
92
+
93
+ # Make prediction
94
+ result = model.predict(image_bytes)
95
+
96
+ return jsonify(result)
97
+
98
+ except Exception as e:
99
+ print(f"❌ API Error: {e}")
100
+ return jsonify({'error': str(e)}), 500
101
+
102
+ @app.route('/predict_base64', methods=['POST'])
103
+ def predict_base64():
104
+ """Prediction endpoint for base64 encoded images"""
105
+ try:
106
+ data = request.get_json()
107
+
108
+ if not data or 'image' not in data:
109
+ return jsonify({'error': 'No image provided'}), 400
110
+
111
+ # Decode base64 image
112
+ image_data = data['image']
113
+ if ',' in image_data:
114
+ image_data = image_data.split(',')[1] # Remove data URL prefix
115
+
116
+ image_bytes = base64.b64decode(image_data)
117
+
118
+ # Make prediction
119
+ result = model.predict(image_bytes)
120
+
121
+ return jsonify(result)
122
+
123
+ except Exception as e:
124
+ print(f"❌ API Error: {e}")
125
+ return jsonify({'error': str(e)}), 500
126
+
127
+ if __name__ == '__main__':
128
+ print("πŸš€ Starting Water Classification API...")
129
+ print("πŸ“‘ Available endpoints:")
130
+ print(" GET /health")
131
+ print(" POST /predict (multipart form data)")
132
+ print(" POST /predict_base64 (JSON with base64 image)")
133
+ print("🌐 Server running on http://localhost:5000")
134
+
135
+ app.run(host='0.0.0.0', port=5000, debug=True)