abesh-meena commited on
Commit
0970951
·
verified ·
1 Parent(s): 54d0c79

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +79 -83
app/main.py CHANGED
@@ -68,31 +68,10 @@ logger = logging.getLogger("deepguard_api")
68
  if not logger.handlers:
69
  logging.basicConfig(level=logging.INFO)
70
 
71
- # Try to load an IMAGE model if present
72
- # Default to the trained hybrid model file name; can be overridden via MODEL_PATH env var.
73
- # Always resolve the path relative to the project root so that cwd doesn't matter.
74
- _raw_model_path = os.environ.get("MODEL_PATH", "hybrid_deepfake_model.h5")
75
- if os.path.isabs(_raw_model_path):
76
- MODEL_PATH = _raw_model_path
77
- else:
78
- MODEL_PATH = os.path.join(ROOT_DIR, _raw_model_path)
79
-
80
- _loaded_model = None
81
- _model_load_error: Optional[str] = None
82
- try:
83
- if os.path.exists(MODEL_PATH):
84
- logger.info(f"Loading IMAGE model from: {MODEL_PATH}")
85
- _loaded_model = model_module.load_model_from_checkpoint(MODEL_PATH)
86
- logger.info(f"Image model loaded successfully: name={getattr(_loaded_model, 'name', 'unknown')}")
87
- else:
88
- _model_load_error = f"Model file not found at path: {MODEL_PATH}"
89
- logger.error(_model_load_error)
90
- except Exception as e:
91
- _loaded_model = None
92
- _model_load_error = f"Error loading model from '{MODEL_PATH}': {e}"
93
- logger.exception(_model_load_error)
94
 
95
- # Optional VIDEO model + feature extractor
96
  _raw_video_model_path = os.environ.get("VIDEO_MODEL_PATH", "video_deepfake_model.h5")
97
  if os.path.isabs(_raw_video_model_path):
98
  VIDEO_MODEL_PATH = _raw_video_model_path
@@ -106,12 +85,32 @@ _video_feature_extractor = None
106
  try:
107
  if os.path.exists(VIDEO_MODEL_PATH):
108
  logger.info(f"Loading VIDEO model from: {VIDEO_MODEL_PATH}")
 
 
 
 
 
 
 
 
109
  _video_model = model_module.load_model_from_checkpoint(VIDEO_MODEL_PATH)
 
 
 
 
 
 
 
110
  _video_feature_extractor = model_module.build_video_feature_extractor()
111
- logger.info(f"Video model loaded successfully: name={getattr(_video_model, 'name', 'unknown')}")
 
 
 
 
 
112
  else:
113
  _video_model_error = f"Video model file not found at path: {VIDEO_MODEL_PATH}"
114
- logger.warning(_video_model_error)
115
  except Exception as e:
116
  _video_model = None
117
  _video_feature_extractor = None
@@ -129,30 +128,31 @@ class PredictResponse(BaseModel):
129
  def health():
130
  """
131
  Simple health check.
132
- Returns whether the model is loaded and exposes basic debug info.
133
  """
134
  return {
135
  "status": "ok",
136
- "model_loaded": _loaded_model is not None,
137
- "model_path": MODEL_PATH,
138
- "model_error": _model_load_error,
139
  "video_model_loaded": _video_model is not None,
140
  "video_model_path": VIDEO_MODEL_PATH,
141
  "video_model_error": _video_model_error,
 
 
142
  }
143
 
144
  @app.post("/predict", response_model=PredictResponse)
145
  async def predict(file: UploadFile = File(...)):
146
  """
147
- Accepts an uploaded image file, returns prediction.
 
148
  """
149
- global _loaded_model, _video_model, _video_feature_extractor
150
- if _loaded_model is None:
151
- # Do NOT silently fall back to an untrained model.
152
- # This would give meaningless predictions (often always one class).
153
- detail = _model_load_error or (
154
- "Model is not loaded. Ensure MODEL_PATH points to a valid trained model "
155
- "(.h5 file) and restart the API server."
156
  )
157
  raise HTTPException(status_code=500, detail=detail)
158
 
@@ -162,54 +162,50 @@ async def predict(file: UploadFile = File(...)):
162
 
163
  video_exts = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm"}
164
 
165
- # ----------------------- VIDEO INPUT -----------------------
166
- if ext in video_exts:
167
- if _video_model is None or _video_feature_extractor is None:
168
- detail = _video_model_error or (
169
- "Video model is not loaded. Ensure VIDEO_MODEL_PATH points to a valid "
170
- "trained video model (.h5) and that TensorFlow/OpenCV are installed."
171
- )
172
- raise HTTPException(status_code=500, detail=detail)
173
-
174
- contents = await file.read()
175
- tmp_path = None
176
- try:
177
- with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp_file:
178
- tmp_file.write(contents)
179
- tmp_path = tmp_file.name
180
-
181
- # Use unified helper that supports both images & videos
182
- result = model_module.predict_from_input_unified(
183
- _loaded_model,
184
- tmp_path,
185
- input_type="video",
186
- video_model=_video_model,
187
- feature_extractor=_video_feature_extractor,
188
- )
189
- return JSONResponse(content=result)
190
- except HTTPException:
191
- raise
192
- except Exception as e:
193
- raise HTTPException(status_code=400, detail=f"Could not process video: {e}")
194
- finally:
195
- if tmp_path and os.path.exists(tmp_path):
196
- try:
197
- os.unlink(tmp_path)
198
- except Exception:
199
- pass
200
-
201
- # ----------------------- IMAGE INPUT -----------------------
202
  contents = await file.read()
 
203
  try:
204
- from PIL import Image
205
-
206
- img = Image.open(io.BytesIO(contents)).convert("RGB")
207
- arr = np.asarray(img.resize((224, 224)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  except Exception as e:
209
- raise HTTPException(status_code=400, detail=f"Could not read image: {e}")
210
-
211
- result = model_module.predict_from_input(_loaded_model, arr)
212
- return JSONResponse(content=result)
 
 
 
213
 
214
  if __name__ == "__main__":
215
  # Run using the already imported `app` instance instead of a string path
 
68
  if not logger.handlers:
69
  logging.basicConfig(level=logging.INFO)
70
 
71
+ # IMAGE MODEL REMOVED - Only Video Model for Memory Optimization
72
+ # Image model loading disabled to save memory
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
+ # VIDEO MODEL ONLY - Optimized Loading
75
  _raw_video_model_path = os.environ.get("VIDEO_MODEL_PATH", "video_deepfake_model.h5")
76
  if os.path.isabs(_raw_video_model_path):
77
  VIDEO_MODEL_PATH = _raw_video_model_path
 
85
  try:
86
  if os.path.exists(VIDEO_MODEL_PATH):
87
  logger.info(f"Loading VIDEO model from: {VIDEO_MODEL_PATH}")
88
+
89
+ # Memory optimization before loading
90
+ import gc
91
+ import tensorflow as tf
92
+ gc.collect()
93
+ tf.keras.backend.clear_session()
94
+
95
+ # Load video model with optimizations
96
  _video_model = model_module.load_model_from_checkpoint(VIDEO_MODEL_PATH)
97
+
98
+ # Optimize for inference
99
+ _video_model.trainable = False
100
+ for layer in _video_model.layers:
101
+ layer.trainable = False
102
+
103
+ # Build feature extractor
104
  _video_feature_extractor = model_module.build_video_feature_extractor()
105
+
106
+ logger.info(f"Video model loaded and optimized successfully!")
107
+
108
+ # Clear memory after loading
109
+ gc.collect()
110
+
111
  else:
112
  _video_model_error = f"Video model file not found at path: {VIDEO_MODEL_PATH}"
113
+ logger.error(_video_model_error)
114
  except Exception as e:
115
  _video_model = None
116
  _video_feature_extractor = None
 
128
  def health():
129
  """
130
  Simple health check.
131
+ Returns whether the video model is loaded and exposes basic debug info.
132
  """
133
  return {
134
  "status": "ok",
135
+ "image_model_enabled": False, # Disabled
 
 
136
  "video_model_loaded": _video_model is not None,
137
  "video_model_path": VIDEO_MODEL_PATH,
138
  "video_model_error": _video_model_error,
139
+ "memory_optimized": True,
140
+ "model_type": "video_only"
141
  }
142
 
143
  @app.post("/predict", response_model=PredictResponse)
144
  async def predict(file: UploadFile = File(...)):
145
  """
146
+ Accepts an uploaded video file, returns prediction.
147
+ ❌ Image processing disabled - Video only for memory optimization.
148
  """
149
+ global _video_model, _video_feature_extractor
150
+
151
+ # IMAGE PROCESSING DISABLED
152
+ if _video_model is None:
153
+ detail = _video_model_error or (
154
+ "Video model is not loaded. Ensure VIDEO_MODEL_PATH points to a valid "
155
+ "trained video model (.h5) and restart the API server."
156
  )
157
  raise HTTPException(status_code=500, detail=detail)
158
 
 
162
 
163
  video_exts = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm"}
164
 
165
+ # ONLY VIDEO PROCESSING ALLOWED
166
+ if ext not in video_exts:
167
+ raise HTTPException(
168
+ status_code=400,
169
+ detail=f" Image processing disabled. Please upload a video file. Supported formats: {', '.join(video_exts)}"
170
+ )
171
+
172
+ # ----------------------- VIDEO PROCESSING ONLY -----------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  contents = await file.read()
174
+ tmp_path = None
175
  try:
176
+ with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp_file:
177
+ tmp_file.write(contents)
178
+ tmp_path = tmp_file.name
179
+
180
+ # Memory optimization before prediction
181
+ import gc
182
+ import tensorflow as tf
183
+ gc.collect()
184
+
185
+ # Use video model for prediction
186
+ result = model_module.predict_from_input_unified(
187
+ None, # No image model
188
+ tmp_path,
189
+ input_type="video",
190
+ video_model=_video_model,
191
+ feature_extractor=_video_feature_extractor,
192
+ )
193
+
194
+ # Clear memory after prediction
195
+ gc.collect()
196
+
197
+ return JSONResponse(content=result)
198
+
199
+ except HTTPException:
200
+ raise
201
  except Exception as e:
202
+ raise HTTPException(status_code=400, detail=f"Could not process video: {e}")
203
+ finally:
204
+ if tmp_path and os.path.exists(tmp_path):
205
+ try:
206
+ os.unlink(tmp_path)
207
+ except Exception:
208
+ pass
209
 
210
  if __name__ == "__main__":
211
  # Run using the already imported `app` instance instead of a string path