gajavegs committed on
Commit
b2b7ddc
·
verified ·
1 Parent(s): 8a2e6e1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from typing import Any

from flask import Flask, jsonify, request, send_from_directory
from PIL import Image
import torch
import torch.nn.functional as F
from dotenv import load_dotenv

from model_loader import load_alexnet_model, preprocess_image

load_dotenv(override=True)

# Hugging Face Spaces injects PORT dynamically; default to 7860 for local runs.
PORT = int(os.getenv("PORT", os.getenv("FLASK_PORT", "7860")))
HOST = "0.0.0.0"
MODEL_PATH = os.getenv("MODEL_PATH", "models/alexnext_vsf_bext.pth")

# Keep torch single-threaded; a single worker is safest for GPU inference.
torch.set_num_threads(1)

# Flask app that also serves the bundled frontend from ./static at the site root.
app = Flask(__name__, static_folder="static", static_url_path="")

# Prefer CUDA when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the classifier exactly once at import time and pin it to eval mode.
model, classes = load_alexnet_model(MODEL_PATH, device=DEVICE)
model.to(DEVICE).eval()
30
@app.get("/")
def root() -> Any:
    """Serve the frontend entry point (static/index.html)."""
    static_dir = app.static_folder
    return send_from_directory(static_dir, "index.html")
34
+
35
@app.get("/health")
def health() -> Any:
    """Liveness probe; reports which device inference runs on."""
    payload = {"status": "ok", "device": str(DEVICE)}
    return jsonify(payload)
38
+
39
def load_image(file_stream):
    """Decode an uploaded file stream into an RGB PIL image."""
    img = Image.open(file_stream)
    return img.convert("RGB")
41
+
42
@app.post("/predict_AlexNet")
def predict_alexnet() -> Any:
    """Classify an uploaded image with the AlexNet model.

    Expects a multipart form with a file field named ``image``.  On
    success returns the top class, its confidence, and the full
    per-class probability map; returns 400 when the field is missing,
    empty, or the image cannot be processed.
    """
    if "image" not in request.files:
        return jsonify({"error": "Missing file field 'image'."}), 400

    upload = request.files["image"]
    if not upload:
        return jsonify({"error": "Empty file."}), 400

    try:
        image = load_image(upload.stream)
        batch = preprocess_image(image).to(DEVICE)

        with torch.no_grad():
            logits = model(batch)
            # Softmax over the single sample's class dimension.
            probs = F.softmax(logits[0], dim=0).detach().cpu()

        top_prob, top_idx = torch.max(probs, dim=0)
        label = classes[int(top_idx)]

        return jsonify(
            {
                "class": label,
                "confidence": float(top_prob),
                "probabilities": dict(zip(classes, map(float, probs.tolist()))),
            }
        )

    except Exception as e:
        # Top-level boundary: surface decode/inference failures to the client.
        return jsonify({"error": f"Failed to process image: {e}"}), 400
73
+
74
+ if __name__ == "__main__":
75
+ debug = bool(int(os.getenv("FLASK_DEBUG", "0")))
76
+ app.run(host=HOST, port=PORT, debug=debug)