itsLu commited on
Commit
5b7332b
·
verified ·
1 Parent(s): 06a7d25

Update backend/main.py

Browse files
Files changed (1) hide show
  1. backend/main.py +81 -52
backend/main.py CHANGED
@@ -2,21 +2,18 @@ from __future__ import annotations
2
 
3
  import os
4
  import tempfile
5
- from typing import List, Optional, Tuple
 
6
 
7
  import numpy as np
8
  from fastapi import FastAPI, File, UploadFile, HTTPException
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.staticfiles import StaticFiles
11
- from fastapi.responses import FileResponse, JSONResponse
12
 
13
  import cv2 # type: ignore
14
  import joblib # type: ignore
15
  import tensorflow as tf # type: ignore
16
  from tensorflow.keras.applications.resnet import preprocess_input
17
- import os
18
- os.environ["SM_FRAMEWORK"] = "tf.keras"
19
-
20
 
21
 
22
  # -----------------------------
@@ -24,10 +21,10 @@ os.environ["SM_FRAMEWORK"] = "tf.keras"
24
  # -----------------------------
25
  MODEL_PATH = os.getenv("MODEL_PATH", "/app/model.keras")
26
  SCALER_PATH = os.getenv("SCALER_PATH", "/app/scaler.save")
 
27
 
28
  _model = None
29
  _scaler = None
30
- _preprocess_input = None
31
 
32
 
33
  def get_model():
@@ -39,7 +36,7 @@ def get_model():
39
  "Place your Keras model at /app/model.keras or set MODEL_PATH."
40
  )
41
  # compile=False is safer for inference-only deployments
42
- _model = tf.keras.models.load_model("/app/model.keras", compile=False)
43
  return _model
44
 
45
 
@@ -55,39 +52,34 @@ def get_scaler():
55
  return _scaler
56
 
57
 
58
- def get_preprocess_input():
59
- global _preprocess_input
60
- if _preprocess_input is None:
61
- _preprocess_input = sm.get_preprocessing("resnet101")
62
- return _preprocess_input
63
-
64
-
65
  # -----------------------------
66
  # Preprocessing (as requested)
67
  # -----------------------------
68
- def load_data(image_path):
69
  image = tf.io.read_file(image_path)
70
  image = tf.io.decode_png(image, channels=3)
71
  image = tf.image.resize(image, [224, 224], method="bilinear")
72
  image = tf.cast(image, tf.float32)
73
- image = preprocess_input(image)
74
  return image
75
 
76
 
77
-
 
 
78
  def extract_frames_to_pngs(video_bytes: bytes, max_frames: int = 300) -> List[str]:
79
  """Decode video bytes with OpenCV and write frames as PNGs to a temp dir.
80
-
81
  Returns a list of PNG file paths.
82
  """
83
  tmpdir = tempfile.mkdtemp(prefix="frames_")
84
  video_path = os.path.join(tmpdir, "input.mp4")
 
85
  with open(video_path, "wb") as f:
86
  f.write(video_bytes)
87
 
88
  cap = cv2.VideoCapture(video_path)
89
  if not cap.isOpened():
90
- raise ValueError("Could not open uploaded video.")
91
 
92
  paths: List[str] = []
93
  idx = 0
@@ -96,8 +88,6 @@ def extract_frames_to_pngs(video_bytes: bytes, max_frames: int = 300) -> List[st
96
  if not ok:
97
  break
98
 
99
- # Ensure RGB->BGR handling: OpenCV reads BGR; we'll write PNG in BGR which is fine,
100
- # because tf.io.decode_png reads the encoded pixels and we treat them as 3 channels.
101
  out_path = os.path.join(tmpdir, f"frame_{idx:05d}.png")
102
  cv2.imwrite(out_path, frame)
103
  paths.append(out_path)
@@ -110,14 +100,16 @@ def extract_frames_to_pngs(video_bytes: bytes, max_frames: int = 300) -> List[st
110
  return paths
111
 
112
 
 
 
 
113
  def moving_average(x: np.ndarray, window: int = 7) -> np.ndarray:
114
  if window <= 1:
115
  return x
116
- window = min(window, x.shape[0])
117
  kernel = np.ones(window, dtype=np.float32) / float(window)
118
- # pad to keep length
119
  pad = window // 2
120
- xpad = np.pad(x, (pad, pad), mode="edge")
121
  return np.convolve(xpad, kernel, mode="valid")
122
 
123
 
@@ -128,30 +120,69 @@ def compute_ef(edv: float, esv: float) -> float:
128
 
129
 
130
  def classify_heart_function(ef: float) -> str:
131
- """Thresholds for EF-based function.
132
-
133
- You asked for best thresholds; without calibration data, we use standard clinical cutoffs:
134
- - Normal: >= 55
135
- - Mild dysfunction: 40–54
136
- - Heart failure: < 40
137
- """
138
  if not np.isfinite(ef):
139
  return "heart failure"
140
-
141
  if ef >= 55.0:
142
  return "normal"
143
  if ef >= 40.0:
144
- # Frontend expects this exact string union.
145
  return "mildly dysfunction"
146
  return "heart failure"
147
 
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  # -----------------------------
150
  # FastAPI app
151
  # -----------------------------
152
  app = FastAPI()
153
 
154
- # Same-origin by default; allow all just in case spaces routes differ
155
  app.add_middleware(
156
  CORSMiddleware,
157
  allow_origins=["*"],
@@ -160,38 +191,34 @@ app.add_middleware(
160
  allow_headers=["*"],
161
  )
162
 
163
- STATIC_DIR = os.getenv("STATIC_DIR", "/app/static")
164
-
165
 
166
  @app.post("/api/analyze")
167
  async def analyze(video: UploadFile = File(...)):
168
- if video.content_type is None or "video" not in video.content_type:
169
- # Some browsers may omit; still attempt
170
- pass
171
-
172
  video_bytes = await video.read()
173
  if not video_bytes:
174
  raise HTTPException(status_code=400, detail="Empty video upload.")
175
 
176
  try:
177
- frame_paths = extract_frames_to_pngs(video_bytes)
178
- except Exception as e:
179
- raise HTTPException(status_code=400, detail=f"Video decode error: {e}")
 
180
 
181
- try:
182
  # Build a batch tensor [N, 224, 224, 3]
183
  batch = tf.stack([load_data(p) for p in frame_paths], axis=0)
 
184
  model = get_model()
185
- preds = model.predict(batch, verbose=0)
186
 
187
- preds_np = np.array(preds)
188
- preds_np = preds_np.reshape(-1, 1)
189
 
190
  scaler = get_scaler()
191
- values = scaler.inverse_transform(preds_np).reshape(-1)
192
 
193
  # Smooth
194
- smooth = moving_average(values, window=int(os.getenv("SMOOTH_WINDOW", "7")))
 
 
195
  edv = float(np.max(smooth))
196
  esv = float(np.min(smooth))
197
  ef = compute_ef(edv, esv)
@@ -202,11 +229,13 @@ async def analyze(video: UploadFile = File(...)):
202
  "heartFunction": heart_fn,
203
  "edv": round(edv, 2),
204
  "esv": round(esv, 2),
205
- "numFrames": int(len(values)),
206
  }
207
- except RuntimeError as e:
208
- raise HTTPException(status_code=500, detail=str(e))
 
209
  except Exception as e:
 
210
  raise HTTPException(status_code=500, detail=f"Inference error: {e}")
211
 
212
 
 
2
 
3
  import os
4
  import tempfile
5
+ import traceback
6
+ from typing import List
7
 
8
  import numpy as np
9
  from fastapi import FastAPI, File, UploadFile, HTTPException
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.staticfiles import StaticFiles
 
12
 
13
  import cv2 # type: ignore
14
  import joblib # type: ignore
15
  import tensorflow as tf # type: ignore
16
  from tensorflow.keras.applications.resnet import preprocess_input
 
 
 
17
 
18
 
19
  # -----------------------------
 
21
  # -----------------------------
22
  MODEL_PATH = os.getenv("MODEL_PATH", "/app/model.keras")
23
  SCALER_PATH = os.getenv("SCALER_PATH", "/app/scaler.save")
24
+ STATIC_DIR = os.getenv("STATIC_DIR", "/app/static")
25
 
26
  _model = None
27
  _scaler = None
 
28
 
29
 
30
  def get_model():
 
36
  "Place your Keras model at /app/model.keras or set MODEL_PATH."
37
  )
38
  # compile=False is safer for inference-only deployments
39
+ _model = tf.keras.models.load_model(MODEL_PATH, compile=False)
40
  return _model
41
 
42
 
 
52
  return _scaler
53
 
54
 
 
 
 
 
 
 
 
55
  # -----------------------------
56
  # Preprocessing (as requested)
57
  # -----------------------------
58
+ def load_data(image_path: str) -> tf.Tensor:
59
  image = tf.io.read_file(image_path)
60
  image = tf.io.decode_png(image, channels=3)
61
  image = tf.image.resize(image, [224, 224], method="bilinear")
62
  image = tf.cast(image, tf.float32)
63
+ image = preprocess_input(image) # ResNet preprocessing
64
  return image
65
 
66
 
67
+ # -----------------------------
68
+ # Video -> frames
69
+ # -----------------------------
70
  def extract_frames_to_pngs(video_bytes: bytes, max_frames: int = 300) -> List[str]:
71
  """Decode video bytes with OpenCV and write frames as PNGs to a temp dir.
 
72
  Returns a list of PNG file paths.
73
  """
74
  tmpdir = tempfile.mkdtemp(prefix="frames_")
75
  video_path = os.path.join(tmpdir, "input.mp4")
76
+
77
  with open(video_path, "wb") as f:
78
  f.write(video_bytes)
79
 
80
  cap = cv2.VideoCapture(video_path)
81
  if not cap.isOpened():
82
+ raise ValueError("Could not open uploaded video. (Unsupported codec/container?)")
83
 
84
  paths: List[str] = []
85
  idx = 0
 
88
  if not ok:
89
  break
90
 
 
 
91
  out_path = os.path.join(tmpdir, f"frame_{idx:05d}.png")
92
  cv2.imwrite(out_path, frame)
93
  paths.append(out_path)
 
100
  return paths
101
 
102
 
103
+ # -----------------------------
104
+ # Post-processing helpers
105
+ # -----------------------------
106
  def moving_average(x: np.ndarray, window: int = 7) -> np.ndarray:
107
  if window <= 1:
108
  return x
109
+ window = int(max(1, min(window, x.shape[0])))
110
  kernel = np.ones(window, dtype=np.float32) / float(window)
 
111
  pad = window // 2
112
+ xpad = np.pad(x.astype(np.float32), (pad, pad), mode="edge")
113
  return np.convolve(xpad, kernel, mode="valid")
114
 
115
 
 
120
 
121
 
122
  def classify_heart_function(ef: float) -> str:
 
 
 
 
 
 
 
123
  if not np.isfinite(ef):
124
  return "heart failure"
 
125
  if ef >= 55.0:
126
  return "normal"
127
  if ef >= 40.0:
 
128
  return "mildly dysfunction"
129
  return "heart failure"
130
 
131
 
132
+ def _normalize_model_output(raw, n_frames: int) -> np.ndarray:
133
+ """
134
+ Normalize model.predict output to shape (N, 1) float array suitable for scaler.inverse_transform.
135
+ Handles models that return:
136
+ - single array: (N,), (N,1), (N,k), (N,1,1), etc.
137
+ - list/tuple of arrays (multi-output)
138
+ """
139
+ # Multi-output model: choose the output whose first dim matches number of frames.
140
+ if isinstance(raw, (list, tuple)):
141
+ shapes = [np.asarray(x).shape for x in raw]
142
+ print("PRED LIST SHAPES:", shapes)
143
+
144
+ chosen = None
145
+ for r in raw:
146
+ r_arr = np.asarray(r)
147
+ if r_arr.ndim >= 1 and r_arr.shape[0] == n_frames:
148
+ chosen = r_arr
149
+ break
150
+ if chosen is None:
151
+ chosen = np.asarray(raw[0])
152
+ raw_arr = chosen
153
+ else:
154
+ raw_arr = np.asarray(raw)
155
+ print("PRED SHAPE:", raw_arr.shape)
156
+
157
+ raw_arr = np.asarray(raw_arr)
158
+
159
+ # Force to (N, 1)
160
+ if raw_arr.ndim == 1:
161
+ raw_arr = raw_arr.reshape(-1, 1)
162
+ elif raw_arr.ndim == 2:
163
+ if raw_arr.shape[0] != n_frames:
164
+ # Sometimes outputs come as (1, N) — fix that
165
+ if raw_arr.shape[1] == n_frames:
166
+ raw_arr = raw_arr.T
167
+ # If multiple columns, pick the first by default
168
+ if raw_arr.shape[1] != 1:
169
+ raw_arr = raw_arr[:, :1]
170
+ else:
171
+ # Flatten everything but the frame dimension
172
+ raw_arr = raw_arr.reshape(raw_arr.shape[0], -1)
173
+ raw_arr = raw_arr[:, :1]
174
+
175
+ if raw_arr.shape[0] != n_frames:
176
+ raise ValueError(f"Prediction length mismatch: got {raw_arr.shape[0]} but expected {n_frames}")
177
+
178
+ return raw_arr.astype(np.float32)
179
+
180
+
181
  # -----------------------------
182
  # FastAPI app
183
  # -----------------------------
184
  app = FastAPI()
185
 
 
186
  app.add_middleware(
187
  CORSMiddleware,
188
  allow_origins=["*"],
 
191
  allow_headers=["*"],
192
  )
193
 
 
 
194
 
195
  @app.post("/api/analyze")
196
  async def analyze(video: UploadFile = File(...)):
 
 
 
 
197
  video_bytes = await video.read()
198
  if not video_bytes:
199
  raise HTTPException(status_code=400, detail="Empty video upload.")
200
 
201
  try:
202
+ frame_paths = extract_frames_to_pngs(
203
+ video_bytes,
204
+ max_frames=int(os.getenv("MAX_FRAMES", "300")),
205
+ )
206
 
 
207
  # Build a batch tensor [N, 224, 224, 3]
208
  batch = tf.stack([load_data(p) for p in frame_paths], axis=0)
209
+
210
  model = get_model()
211
+ raw_preds = model.predict(batch, verbose=0)
212
 
213
+ preds_np = _normalize_model_output(raw_preds, n_frames=batch.shape[0])
 
214
 
215
  scaler = get_scaler()
216
+ values = scaler.inverse_transform(preds_np).reshape(-1).astype(np.float32)
217
 
218
  # Smooth
219
+ smooth_window = int(os.getenv("SMOOTH_WINDOW", "7"))
220
+ smooth = moving_average(values, window=smooth_window)
221
+
222
  edv = float(np.max(smooth))
223
  esv = float(np.min(smooth))
224
  ef = compute_ef(edv, esv)
 
229
  "heartFunction": heart_fn,
230
  "edv": round(edv, 2),
231
  "esv": round(esv, 2),
232
+ "numFrames": int(values.shape[0]),
233
  }
234
+
235
+ except HTTPException:
236
+ raise
237
  except Exception as e:
238
+ print("ANALYZE ERROR TRACEBACK:\n", traceback.format_exc())
239
  raise HTTPException(status_code=500, detail=f"Inference error: {e}")
240
 
241