waleeyd commited on
Commit
b0e1c00
·
verified ·
1 Parent(s): 7c27e08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -71
app.py CHANGED
@@ -1,30 +1,21 @@
1
  from flask import Flask, request, render_template, jsonify
2
- from flask_limiter import Limiter
3
- from flask_limiter.util import get_remote_address
4
- from flask_limiter.errors import RateLimitExceeded
5
- from flask_wtf.csrf import CSRFProtect
6
- from flask_wtf import FlaskForm
7
- from wtforms import FileField
8
- from wtforms.validators import DataRequired
9
  from werkzeug.utils import secure_filename
10
- from functools import wraps
11
  from PIL import Image
12
  import numpy as np
13
  import torch
14
  import torchvision.transforms as T
15
- import random
16
  import os
17
  import time
18
  import hashlib
19
  import logging
20
- import secrets
21
 
22
  # ===============================
23
- # Deterministic inference
24
  # ===============================
25
- torch.manual_seed(42)
26
- np.random.seed(42)
27
- random.seed(42)
 
28
 
29
  # ===============================
30
  # App initialization
@@ -32,29 +23,11 @@ random.seed(42)
32
  app = Flask(__name__)
33
 
34
  app.config.update(
35
- SECRET_KEY=secrets.token_hex(32),
36
  MAX_CONTENT_LENGTH=30 * 1024 * 1024, # 30 MB
37
- WTF_CSRF_TIME_LIMIT=None,
38
  UPLOAD_FOLDER="uploads"
39
  )
40
 
41
- csrf = CSRFProtect(app)
42
-
43
- limiter = Limiter(
44
- key_func=get_remote_address,
45
- default_limits=["100 per hour", "20 per minute"]
46
- )
47
- limiter.init_app(app)
48
-
49
- # ===============================
50
- # Logging
51
- # ===============================
52
- logging.basicConfig(
53
- level=logging.INFO,
54
- format="%(asctime)s - %(levelname)s - %(message)s",
55
- handlers=[logging.FileHandler("security.log"), logging.StreamHandler()]
56
- )
57
-
58
  # ===============================
59
  # Load Model
60
  # ===============================
@@ -69,15 +42,16 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
69
  try:
70
  from transformers import SiglipForImageClassification
71
 
 
72
  model = SiglipForImageClassification.from_pretrained(
73
  MODEL_NAME,
74
  torch_dtype=torch.float32
75
  ).to(device)
76
  model.eval()
77
 
78
- logging.info(f"Model loaded successfully on {device}")
79
  except Exception as e:
80
- logging.error(f"Model load failed: {e}")
81
  model = None
82
 
83
  # ===============================
@@ -98,18 +72,8 @@ MAX_FILE_SIZE = 30 * 1024 * 1024
98
  os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True)
99
 
100
  # ===============================
101
- # Helpers
102
  # ===============================
103
- def security_validate(f):
104
- @wraps(f)
105
- def wrapper(*args, **kwargs):
106
- logging.info(f"{request.remote_addr} → {request.endpoint}")
107
- return f(*args, **kwargs)
108
- return wrapper
109
-
110
- class UploadForm(FlaskForm):
111
- image = FileField("Image", validators=[DataRequired()])
112
-
113
  def allowed_file(filename):
114
  return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
115
 
@@ -181,15 +145,15 @@ def predict_image(path):
181
  # Routes
182
  # ===============================
183
  @app.route("/", methods=["GET"])
184
- @limiter.limit("30 per minute")
185
- @security_validate
186
  def index():
187
- return render_template("index.html", form=UploadForm())
 
 
 
 
 
188
 
189
  @app.route("/predict", methods=["POST"])
190
- @csrf.exempt
191
- @limiter.limit("60 per minute")
192
- @security_validate
193
  def predict():
194
  if "image" not in request.files:
195
  return jsonify({"error": "No image uploaded"}), 400
@@ -223,7 +187,10 @@ def predict():
223
 
224
  finally:
225
  if os.path.exists(path):
226
- os.remove(path)
 
 
 
227
 
228
  # ===============================
229
  # Security headers
@@ -231,30 +198,16 @@ def predict():
231
  @app.after_request
232
  def security_headers(res):
233
  res.headers["X-Content-Type-Options"] = "nosniff"
234
- res.headers["X-Frame-Options"] = "DENY"
235
- res.headers["X-XSS-Protection"] = "1; mode=block"
236
- res.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
237
- res.headers["Content-Security-Policy"] = (
238
- "default-src 'self' https://cdnjs.cloudflare.com; "
239
- "style-src 'self' 'unsafe-inline' https://cdnjs.cloudflare.com; "
240
- "script-src 'self' 'unsafe-inline'; "
241
- "img-src 'self' data:; "
242
- "font-src 'self' https://cdnjs.cloudflare.com; "
243
- "frame-src https://www.youtube.com https://www.youtube-nocookie.com;"
244
- )
245
  return res
246
 
247
  # ===============================
248
- # JSON Error Handlers
249
  # ===============================
250
  @app.errorhandler(413)
251
  def request_entity_too_large(error):
252
  return jsonify({"error": "File too large. Maximum allowed size is 30 MB."}), 413
253
 
254
- @app.errorhandler(RateLimitExceeded)
255
- def rate_limit_handler(e):
256
- return jsonify({"error": "Too many requests. Please slow down."}), 429
257
-
258
  @app.errorhandler(400)
259
  def bad_request(error):
260
  return jsonify({"error": "Invalid request or image file."}), 400
@@ -267,8 +220,8 @@ def internal_error(error):
267
  # Run
268
  # ===============================
269
  if __name__ == "__main__":
270
- # Try multiple port options for different platforms
271
- port = int(os.environ.get("PORT", os.environ.get("GRADIO_SERVER_PORT", 7860)))
272
  app.run(
273
  debug=False,
274
  host="0.0.0.0",
 
1
  from flask import Flask, request, render_template, jsonify
 
 
 
 
 
 
 
2
  from werkzeug.utils import secure_filename
 
3
  from PIL import Image
4
  import numpy as np
5
  import torch
6
  import torchvision.transforms as T
 
7
  import os
8
  import time
9
  import hashlib
10
  import logging
 
11
 
12
  # ===============================
13
+ # Logging
14
  # ===============================
15
+ logging.basicConfig(
16
+ level=logging.INFO,
17
+ format="%(asctime)s - %(levelname)s - %(message)s"
18
+ )
19
 
20
  # ===============================
21
  # App initialization
 
23
  app = Flask(__name__)
24
 
25
  app.config.update(
26
+ SECRET_KEY=os.environ.get('SECRET_KEY', 'dev-secret-key-change-in-production'),
27
  MAX_CONTENT_LENGTH=30 * 1024 * 1024, # 30 MB
 
28
  UPLOAD_FOLDER="uploads"
29
  )
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  # ===============================
32
  # Load Model
33
  # ===============================
 
42
  try:
43
  from transformers import SiglipForImageClassification
44
 
45
+ logging.info(f"Loading model from {MODEL_NAME}...")
46
  model = SiglipForImageClassification.from_pretrained(
47
  MODEL_NAME,
48
  torch_dtype=torch.float32
49
  ).to(device)
50
  model.eval()
51
 
52
+ logging.info(f"Model loaded successfully on {device}")
53
  except Exception as e:
54
+ logging.error(f"Model load failed: {e}")
55
  model = None
56
 
57
  # ===============================
 
72
  os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True)
73
 
74
  # ===============================
75
+ # Helper Functions
76
  # ===============================
 
 
 
 
 
 
 
 
 
 
77
  def allowed_file(filename):
78
  return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
79
 
 
145
  # Routes
146
  # ===============================
147
  @app.route("/", methods=["GET"])
 
 
148
  def index():
149
+ return render_template("index.html")
150
+
151
+ @app.route("/health", methods=["GET"])
152
+ def health():
153
+ """Health check endpoint for deployment platforms"""
154
+ return jsonify({"status": "healthy", "model_loaded": model is not None}), 200
155
 
156
  @app.route("/predict", methods=["POST"])
 
 
 
157
  def predict():
158
  if "image" not in request.files:
159
  return jsonify({"error": "No image uploaded"}), 400
 
187
 
188
  finally:
189
  if os.path.exists(path):
190
+ try:
191
+ os.remove(path)
192
+ except:
193
+ pass
194
 
195
  # ===============================
196
  # Security headers
 
198
  @app.after_request
199
  def security_headers(res):
200
  res.headers["X-Content-Type-Options"] = "nosniff"
201
+ res.headers["X-Frame-Options"] = "SAMEORIGIN"
 
 
 
 
 
 
 
 
 
 
202
  return res
203
 
204
  # ===============================
205
+ # Error Handlers
206
  # ===============================
207
  @app.errorhandler(413)
208
  def request_entity_too_large(error):
209
  return jsonify({"error": "File too large. Maximum allowed size is 30 MB."}), 413
210
 
 
 
 
 
211
  @app.errorhandler(400)
212
  def bad_request(error):
213
  return jsonify({"error": "Invalid request or image file."}), 400
 
220
  # Run
221
  # ===============================
222
  if __name__ == "__main__":
223
+ port = int(os.environ.get("PORT", 7860))
224
+ logging.info(f"Starting Flask app on port {port}...")
225
  app.run(
226
  debug=False,
227
  host="0.0.0.0",