Hussein El-Hadidy commited on
Commit
d460d97
·
1 Parent(s): a9142c5

Deploy latest version to Hugging Face Space

Browse files
Files changed (1) hide show
  1. main.py +0 -547
main.py DELETED
@@ -1,547 +0,0 @@
1
- import os
2
- import pickle
3
- import shutil
4
- import uuid
5
- from fastapi import FastAPI, File, UploadFile
6
- from fastapi.responses import JSONResponse
7
- from pymongo.mongo_client import MongoClient
8
- from pymongo.server_api import ServerApi
9
- import cloudinary
10
- import cloudinary.uploader
11
- from cloudinary.utils import cloudinary_url
12
- from SkinBurns_Classification import extract_features
13
- from SkinBurns_Segmentation import segment_burn
14
- import requests
15
- import joblib
16
- import numpy as np
17
- from ECG.ECG_Classify import classify_ecg
18
- from ECG.ECG_MultiClass import analyze_ecg_pdf
19
- from ultralytics import YOLO
20
- import tensorflow as tf
21
- from fastapi import HTTPException
22
- from fastapi import WebSocket, WebSocketDisconnect
23
- import base64
24
- import cv2
25
- import time
26
- from CPR.CPRAnalyzer import CPRAnalyzer as OfflineAnalyzer
27
- import tempfile
28
- import matplotlib.pyplot as plt
29
- import json
30
- import asyncio
31
- import concurrent.futures
32
- from CPRRealTime.main import CPRAnalyzer as RealtimeAnalyzer
33
- from threading import Thread
34
- from starlette.responses import StreamingResponse
35
- import threading
36
- import queue
37
- from CPRRealTime.analysis_socket_server import AnalysisSocketServer # adjust if needed
38
- from CPRRealTime.logging_config import cpr_logger
39
- import logging
40
- import sys
41
- import re
42
- import signal
43
-
44
-
45
- app = FastAPI()
46
-
47
- SCREENSHOTS_DIR = "screenshots" # Folder containing screenshots to upload
48
- OUTPUT_DIR = "Output" # Folder containing the .mp4 video and graph .png
49
- UPLOAD_DIR = "uploads"
50
- os.makedirs(UPLOAD_DIR, exist_ok=True)
51
-
52
- # Load the YOLO model
53
- try:
54
- model = YOLO("yolo11n-pose_float16.tflite")
55
- print("Model loaded successfully")
56
- except Exception as e:
57
- print(f"❌ Model loading failed: {str(e)}")
58
- model = None
59
-
60
-
61
-
62
- # ✅ Cloudinary config
63
- cloudinary.config(
64
- cloud_name = "darumyfpl",
65
- api_key = "493972437417214",
66
- api_secret = "jjOScVGochJYA7IxDam7L4HU2Ig", # Replace in production
67
- secure=True
68
- )
69
-
70
- # Basic Hello route
71
- @app.get("/")
72
- def greet_json():
73
- return {"Hello": "World!"}
74
-
75
- @app.post("/predict_burn")
76
- async def predict_burn(file: UploadFile = File(...)):
77
- try:
78
- # Save the uploaded file temporarily
79
- temp_file_path = f"temp_{file.filename}"
80
- with open(temp_file_path, "wb") as temp_file:
81
- temp_file.write(await file.read())
82
-
83
-
84
- # Load the saved SVM model
85
- with open('svm_model.pkl', 'rb') as model_file:
86
- loaded_svm = pickle.load(model_file)
87
-
88
- # Extract features from the uploaded image
89
- features = extract_features(temp_file_path)
90
-
91
- # Remove the temporary file
92
- os.remove(temp_file_path)
93
-
94
- if features is None:
95
- return JSONResponse(content={"error": "Failed to extract features from the image."}, status_code=400)
96
-
97
- # Reshape features to match the SVM model's expected input
98
- features = features.reshape(1, -1)
99
-
100
- # Predict the class
101
- prediction = loaded_svm.predict(features)
102
- prediction_label = "Burn" if prediction[0] == 1 else "No Burn"
103
-
104
- if prediction[0] == 1:
105
- prediction_label = "First Class"
106
- elif prediction[0] == 2:
107
- prediction_label = "Second Class"
108
- else:
109
- prediction_label = "Zero Class"
110
-
111
- return {
112
- "prediction": prediction_label
113
- }
114
-
115
- except Exception as e:
116
- return JSONResponse(content={"error": str(e)}, status_code=500)
117
-
118
- @app.post("/segment_burn")
119
- async def segment_burn_endpoint(reference: UploadFile = File(...), patient: UploadFile = File(...)):
120
- try:
121
- # Save the reference image temporarily
122
- reference_path = f"temp_ref_{reference.filename}"
123
- reference_bytes = await reference.read()
124
- with open(reference_path, "wb") as ref_file:
125
- ref_file.write(reference_bytes)
126
-
127
- # Save the patient image temporarily
128
- patient_path = f"temp_patient_{patient.filename}"
129
- patient_bytes = await patient.read()
130
- with open(patient_path, "wb") as pat_file:
131
- pat_file.write(patient_bytes)
132
-
133
- # Call the segmentation logic
134
- burn_crop_clean, burn_crop_debug = segment_burn(patient_path, reference_path)
135
-
136
- # Save the cropped outputs
137
- burn_crop_clean_path = f"temp_burn_crop_clean_{uuid.uuid4()}.png"
138
- burn_crop_debug_path = f"temp_burn_crop_debug_{uuid.uuid4()}.png"
139
-
140
-
141
- plt.imsave(burn_crop_clean_path, burn_crop_clean)
142
- plt.imsave(burn_crop_debug_path, burn_crop_debug)
143
-
144
- # Upload to Cloudinary
145
- crop_clean_upload = cloudinary.uploader.upload(burn_crop_clean_path, public_id=f"ref_{reference.filename}")
146
- crop_debug_upload = cloudinary.uploader.upload(burn_crop_debug_path, public_id=f"pat_{patient.filename}")
147
- crop_clean_url = crop_clean_upload["secure_url"]
148
- crop_debug_url = crop_debug_upload["secure_url"]
149
-
150
- # Clean up temp files
151
-
152
- os.remove(burn_crop_clean_path)
153
- os.remove(burn_crop_debug_path)
154
-
155
-
156
- return {
157
- "crop_clean_url": crop_clean_url,
158
- "crop_debug_url": crop_debug_url
159
- }
160
-
161
- except Exception as e:
162
- return JSONResponse(content={"error": str(e)}, status_code=500)
163
-
164
-
165
- @app.post("/classify-ecg")
166
- async def classify_ecg_endpoint(file: UploadFile = File(...)):
167
- model = joblib.load('voting_classifier.pkl')
168
- # Load the model
169
-
170
- try:
171
- # Save the uploaded file temporarily
172
- temp_file_path = f"temp_{file.filename}"
173
- with open(temp_file_path, "wb") as temp_file:
174
- temp_file.write(await file.read())
175
-
176
- # Call the ECG classification function
177
- result = classify_ecg(temp_file_path, model, debug=True, is_pdf=True)
178
-
179
- # Remove the temporary file
180
- os.remove(temp_file_path)
181
-
182
- return {"result": result}
183
-
184
- except Exception as e:
185
- return JSONResponse(content={"error": str(e)}, status_code=500)
186
-
187
- @app.post("/diagnose-ecg")
188
- async def diagnose_ecg(file: UploadFile = File(...)):
189
- try:
190
- # Save the uploaded file temporarily
191
- temp_file_path = f"temp_{file.filename}"
192
- with open(temp_file_path, "wb") as temp_file:
193
- temp_file.write(await file.read())
194
-
195
- model_path = 'deep-multiclass.h5' # Update with actual path
196
- mlb_path = 'deep-multiclass.pkl' # Update with actual path
197
-
198
-
199
- # Call the ECG classification function
200
- result = analyze_ecg_pdf(
201
- temp_file_path,
202
- model_path,
203
- mlb_path,
204
- cleanup=False, # Keep the digitized file
205
- debug=False, # Print debug information
206
- visualize=False # Visualize the digitized signal
207
- )
208
-
209
-
210
- # Remove the temporary file
211
- os.remove(temp_file_path)
212
-
213
- if result and result["diagnosis"]:
214
- return {"result": result["diagnosis"]}
215
- else:
216
- return {"result": "No diagnosis"}
217
-
218
- except Exception as e:
219
- return JSONResponse(content={"error": str(e)}, status_code=500)
220
-
221
-
222
- def clean_warning_name(filename: str) -> str:
223
- """
224
- Remove frame index and underscores from filename base
225
- E.g. "posture_001.png" -> "posture"
226
- """
227
- name, _ = os.path.splitext(filename)
228
- # Remove trailing underscore + digits
229
- cleaned = re.sub(r'_\d+$', '', name)
230
- # Remove all underscores in the name for description
231
- cleaned_desc = cleaned.replace('_', ' ')
232
- return cleaned, cleaned_desc
233
-
234
- @app.post("/process_video")
235
- async def process_video(file: UploadFile = File(...)):
236
- if not file.content_type.startswith("video/"):
237
- raise HTTPException(status_code=400, detail="File must be a video.")
238
-
239
- print("File content type:", file.content_type)
240
- print("File filename:", file.filename)
241
-
242
- # Prepare directories
243
- os.makedirs(UPLOAD_DIR, exist_ok=True)
244
- os.makedirs(SCREENSHOTS_DIR, exist_ok=True)
245
- os.makedirs(OUTPUT_DIR, exist_ok=True)
246
-
247
- folders = ["screenshots", "uploads", "Output"]
248
-
249
- for folder in folders:
250
- if os.path.exists(folder):
251
- for filename in os.listdir(folder):
252
- file_path = os.path.join(folder, filename)
253
- if os.path.isfile(file_path):
254
- os.remove(file_path)
255
-
256
- # Save uploaded video file
257
- video_path = os.path.join(UPLOAD_DIR, file.filename)
258
- with open(video_path, "wb") as buffer:
259
- shutil.copyfileobj(file.file, buffer)
260
-
261
- print(f"\n[API] CPR Analysis Started on {video_path}")
262
-
263
- # Prepare output paths for the analyzer
264
- video_output_path = os.path.join(OUTPUT_DIR, "Myoutput.mp4")
265
- plot_output_path = os.path.join(OUTPUT_DIR, "Myoutput.png")
266
-
267
- # Initialize analyzer with input video and output paths
268
- start_time = time.time()
269
- analyzer = OfflineAnalyzer(video_path, video_output_path, plot_output_path, requested_fps=30)
270
-
271
- # Run the analysis (choose your method)
272
- chunks = analyzer.run_analysis_video()
273
-
274
- warnings = [] # Start empty list
275
-
276
- # Upload screenshots and build warnings list with descriptions and URLs
277
- if os.path.exists(SCREENSHOTS_DIR):
278
- for filename in os.listdir(SCREENSHOTS_DIR):
279
- if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
280
- local_path = os.path.join(SCREENSHOTS_DIR, filename)
281
- cleaned_name, description = clean_warning_name(filename)
282
-
283
- upload_result = cloudinary.uploader.upload(
284
- local_path,
285
- folder="posture_warnings",
286
- public_id=cleaned_name,
287
- overwrite=True
288
- )
289
-
290
- # Add new warning with image_url and description
291
- warnings.append({
292
- "image_url": upload_result['secure_url'],
293
- "description": description
294
- })
295
-
296
- video_path = "Output/Myoutput_final.mp4"
297
-
298
- if os.path.isfile(video_path):
299
- upload_result = cloudinary.uploader.upload_large(
300
- video_path,
301
- resource_type="video",
302
- folder="output_videos",
303
- public_id="Myoutput_final",
304
- overwrite=True
305
- )
306
- wholevideoURL = upload_result['secure_url']
307
- else:
308
- wholevideoURL = None
309
-
310
- # Upload graph output
311
- graphURL = None
312
- if os.path.isfile(plot_output_path):
313
- upload_graph_result = cloudinary.uploader.upload(
314
- plot_output_path,
315
- folder="output_graphs",
316
- public_id=os.path.splitext(os.path.basename(plot_output_path))[0],
317
- overwrite=True
318
- )
319
- graphURL = upload_graph_result['secure_url']
320
-
321
- print(f"[API] CPR Analysis Completed on {video_path}")
322
- analysis_time = time.time() - start_time
323
- print(f"[TIMING] Analysis time: {analysis_time:.2f}s")
324
-
325
- if wholevideoURL is None:
326
- raise HTTPException(status_code=500, detail="No chunk data was generated from the video.")
327
-
328
- return JSONResponse(content={
329
- "videoURL": wholevideoURL,
330
- "graphURL": graphURL,
331
- "warnings": warnings,
332
- "chunks": chunks,
333
- })
334
-
335
-
336
- # @app.websocket("/ws/process_video")
337
- # async def websocket_process_video(websocket: WebSocket):
338
-
339
- # await websocket.accept()
340
-
341
- # frame_buffer = []
342
- # frame_limit = 50
343
- # frame_size = (640, 480) # Adjust if needed
344
- # fps = 30 # Adjust if needed
345
- # loop = asyncio.get_event_loop()
346
-
347
- # # Progress reporting during analysis
348
- # async def progress_callback(data):
349
- # await websocket.send_text(json.dumps(data))
350
-
351
- # def sync_callback(data):
352
- # asyncio.run_coroutine_threadsafe(progress_callback(data), loop)
353
-
354
- # def save_frames_to_video(frames, path):
355
- # out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), fps, frame_size)
356
- # for frame in frames:
357
- # resized = cv2.resize(frame, frame_size)
358
- # out.write(resized)
359
- # out.release()
360
-
361
- # def run_analysis_on_buffer(frames):
362
- # try:
363
- # tmp_path = "temp_video.mp4"
364
- # save_frames_to_video(frames, tmp_path)
365
-
366
- # # Notify: video saved
367
- # asyncio.run_coroutine_threadsafe(
368
- # websocket.send_text(json.dumps({
369
- # "status": "info",
370
- # "message": "Video saved. Starting CPR analysis..."
371
- # })),
372
- # loop
373
- # )
374
-
375
- # # Run analysis
376
- # analyzer = CPRAnalyzer(video_path=tmp_path)
377
- # analyzer.run_analysis(progress_callback=sync_callback)
378
-
379
- # except Exception as e:
380
- # asyncio.run_coroutine_threadsafe(
381
- # websocket.send_text(json.dumps({"error": str(e)})),
382
- # loop
383
- # )
384
-
385
- # try:
386
- # while True:
387
- # data: bytes = await websocket.receive_bytes()
388
- # np_arr = np.frombuffer(data, np.uint8)
389
- # frame = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
390
- # if frame is None:
391
- # continue
392
-
393
- # frame_buffer.append(frame)
394
- # print(f"Frame added to buffer: {len(frame_buffer)}")
395
-
396
- # if len(frame_buffer) == frame_limit:
397
- # # Notify Flutter that we're switching to processing
398
- # await websocket.send_text(json.dumps({
399
- # "status": "ready",
400
- # "message": "Prepare Right CPR: First 150 frames received. Starting processing."
401
- # }))
402
-
403
- # # Copy and clear buffer
404
- # buffer_copy = frame_buffer[:]
405
- # frame_buffer.clear()
406
-
407
- # # Launch background processing
408
- # executor = concurrent.futures.ThreadPoolExecutor()
409
- # loop.run_in_executor(executor, run_analysis_on_buffer, buffer_copy)
410
- # else:
411
- # # Tell Flutter to send the next frame
412
- # await websocket.send_text(json.dumps({
413
- # "status": "continue",
414
- # "message": f"Frame {len(frame_buffer)} received. Send next."
415
- # }))
416
-
417
- # except WebSocketDisconnect:
418
- # print("Client disconnected")
419
-
420
- # except Exception as e:
421
- # await websocket.send_text(json.dumps({"error": str(e)}))
422
-
423
- # finally:
424
- # cv2.destroyAllWindows()
425
-
426
-
427
- logger = logging.getLogger("cpr_logger")
428
- clients = set()
429
- analyzer_thread = None
430
- analysis_started = False
431
- analyzer_lock = threading.Lock()
432
- socket_server: AnalysisSocketServer = None # Global reference
433
-
434
-
435
- async def forward_results_from_queue(websocket: WebSocket, warning_queue):
436
- try:
437
- while True:
438
- warnings = await asyncio.to_thread(warning_queue.get)
439
- serialized = json.dumps(warnings)
440
- await websocket.send_text(serialized)
441
- except asyncio.CancelledError:
442
- logger.info("[WebSocket] Forwarding task cancelled")
443
- except Exception as e:
444
- logger.error(f"[WebSocket] Error forwarding data: {e}")
445
-
446
-
447
- def run_cpr_analysis(source, requested_fps, output_path):
448
- global socket_server
449
- logger.info(f"[MAIN] CPR Analysis Started")
450
-
451
- requested_fps = 30
452
- input_video = source
453
-
454
- output_dir = r"D:\BackendGp\Deploy_El7a2ny_Application\CPRRealTime\outputs"
455
- os.makedirs(output_dir, exist_ok=True)
456
-
457
- video_output_path = os.path.join(output_dir, "output.mp4")
458
- plot_output_path = os.path.join(output_dir, "output.png")
459
-
460
- logger.info(f"[CONFIG] Input video: {input_video}")
461
- logger.info(f"[CONFIG] Video output: {video_output_path}")
462
- logger.info(f"[CONFIG] Plot output: {plot_output_path}")
463
-
464
- initialization_start_time = time.time()
465
- analyzer = RealtimeAnalyzer(input_video, video_output_path, plot_output_path, requested_fps)
466
- socket_server = analyzer.socket_server
467
- analyzer.plot_output_path = plot_output_path
468
-
469
- elapsed_time = time.time() - initialization_start_time
470
- logger.info(f"[TIMING] Initialization time: {elapsed_time:.2f}s")
471
-
472
- try:
473
- analyzer.run_analysis()
474
- finally:
475
- if analyzer.socket_server:
476
- analyzer.socket_server.stop_server()
477
- logger.info("[MAIN] Analyzer stopped")
478
-
479
-
480
- @app.websocket("/ws/real")
481
- async def websocket_analysis(websocket: WebSocket):
482
- global analyzer_thread, analysis_started, socket_server
483
-
484
- await websocket.accept()
485
- clients.add(websocket)
486
- logger.info("[WebSocket] Flutter connected")
487
-
488
- try:
489
- # Wait for the client to send the stream URL as first message
490
- source = await websocket.receive_text()
491
- logger.info(f"[WebSocket] Received stream URL: {source}")
492
-
493
- # Ensure analyzer starts only once using a thread-safe lock
494
- with analyzer_lock:
495
- if not analysis_started:
496
- requested_fps = 30
497
- output_path = r"D:\CPR\End to End\Code Refactor\output\output.mp4"
498
-
499
- analyzer_thread = threading.Thread(
500
- target=run_cpr_analysis,
501
- args=(source, requested_fps, output_path),
502
- daemon=True
503
- )
504
- analyzer_thread.start()
505
- analysis_started = True
506
- logger.info("[WebSocket] Analysis thread started")
507
-
508
- # Rest of your existing code remains exactly the same...
509
- while socket_server is None or socket_server.warning_queue is None:
510
- await asyncio.sleep(0.1)
511
-
512
- forward_task = asyncio.create_task(
513
- forward_results_from_queue(websocket, socket_server.warning_queue)
514
- )
515
-
516
- while True:
517
- await asyncio.sleep(1) # Keep alive
518
-
519
- except WebSocketDisconnect:
520
- logger.warning("[WebSocket] Client disconnected")
521
- if 'forward_task' in locals():
522
- forward_task.cancel()
523
- except Exception as e:
524
- logger.error(f"[WebSocket] Error receiving stream URL: {str(e)}")
525
- await websocket.close(code=1011) # 1011 = Internal Error
526
- finally:
527
- clients.discard(websocket)
528
- logger.info(f"[WebSocket] Active clients: {len(clients)}")
529
-
530
- if not clients and socket_server:
531
- logger.info("[WebSocket] No clients left. Stopping analyzer.")
532
- socket_server.stop_server()
533
- analysis_started = False
534
- socket_server = None
535
-
536
-
537
- def shutdown_handler(signum, frame):
538
- logger.info("Received shutdown signal")
539
- if socket_server:
540
- try:
541
- socket_server.stop_server()
542
- except Exception as e:
543
- logger.warning(f"Error during socket server shutdown: {e}")
544
- os._exit(0)
545
-
546
- signal.signal(signal.SIGINT, shutdown_handler)
547
- signal.signal(signal.SIGTERM, shutdown_handler)