mlbench123 commited on
Commit
e30e308
Β·
verified Β·
1 Parent(s): b58eb15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -80
app.py CHANGED
@@ -1,71 +1,77 @@
1
  import gradio as gr
2
  import numpy as np
3
  from PIL import Image
4
- import cv2 # OpenCV for video processing (if used)
 
 
 
5
  from extract_frames import video_to_keyframes
6
  from apply_mask import apply_mask_and_crop
7
  from run_gmm import run_gmm_inference
8
  from compose_video import compose_final_video
9
- # import the processing functions from original app
10
- # from heatmap_module import video_to_keyframes, apply_mask_and_crop, run_gmm_inference, compose_final_video
11
 
12
- # Helper to extract first frame for mask drawing
 
 
 
 
 
 
 
 
 
 
 
 
13
  def get_first_frame(video_path):
14
  cap = cv2.VideoCapture(video_path)
15
  success, frame = cap.read()
16
  cap.release()
17
  if success:
18
- # Convert BGR to RGB color for PIL/Gradio
19
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
20
  return Image.fromarray(frame)
21
- else:
22
- return None
23
-
24
- # Helper to get mask from drawn data
25
- def extract_mask_from_drawn(composite_image, background_image):
26
- # Convert to numpy arrays for comparison
27
- comp = np.array(composite_image)
28
- bg = np.array(background_image)
29
- if comp.shape != bg.shape:
30
- # If background not same shape as composite, just threshold comp
31
- gray = comp if comp.ndim == 2 else comp[..., :3].mean(axis=-1)
32
- mask = (gray > 10).astype(np.uint8) # simple threshold
33
- else:
34
- # Compute difference where composite != background (assuming draw color != background)
35
- diff = np.any(comp != bg, axis=-1)
36
- mask = diff.astype(np.uint8)
37
- return mask * 255 # return as binary mask image (255 inside mask)
38
-
39
- def process_video(video_file, mask_image, drawn_editor, progress=gr.Progress()):
40
- import os
41
- import uuid
42
-
43
- # Prepare all output folders
44
  base_dir = "video_outputs"
45
  extracted_dir = os.path.join(base_dir, "extracted_frames")
46
  masked_dir = os.path.join(base_dir, "masked_frames")
47
  heatmap_dir = os.path.join(base_dir, "output_heatmap")
 
 
 
 
 
 
 
 
 
 
48
 
49
- os.makedirs("video_outputs/extracted_frames", exist_ok=True)
50
- os.makedirs("video_outputs/masked_frames", exist_ok=True)
51
- os.makedirs("video_outputs/output_heatmap", exist_ok=True)
52
- os.makedirs("video_inputs", exist_ok=True)
53
- os.makedirs("assets", exist_ok=True)
54
-
55
- # Choose mask: from upload or drawing
56
- if mask_image is not None:
57
- mask = mask_image
58
- if mask.ndim == 3:
59
- mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
60
- _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
61
- elif drawn_editor and "composite" in drawn_editor and "background" in drawn_editor:
62
- mask = extract_mask_from_drawn(drawn_editor)
63
- else:
64
- raise gr.Error("Please provide a valid mask (uploaded or drawn).")
65
 
66
  progress(0, desc="Extracting keyframes...")
67
  video_to_keyframes(video_file, extracted_dir)
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  progress(0.3, desc="Applying mask and cropping...")
70
  apply_mask_and_crop(extracted_dir, mask, masked_dir)
71
 
@@ -77,53 +83,58 @@ def process_video(video_file, mask_image, drawn_editor, progress=gr.Progress()):
77
  result_path = os.path.join(base_dir, video_name)
78
  compose_final_video(mask, heatmap_dir, extracted_dir, result_path)
79
 
80
- progress(1.0, desc="Done!")
 
81
 
82
- return "βœ… Heatmap video generated!", result_path, result_path
83
 
84
-
85
- # Define the Gradio app layout
86
  custom_css = """
87
- .gradio-container {background: url('/gradio_api/file=background.jpg') center/cover no-repeat !important;
88
- background-color: #000 !important;}
89
- .panel {max-width: 800px; margin: 2rem auto; padding: 2rem; background: rgba(30,30,30, 0.8); border-radius: 8px;}
 
 
 
 
 
 
 
 
90
  """
91
- with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css, title="Heatmap Generator") as demo:
92
- gr.Markdown("## πŸŽ₯ Heatmap Generator", elem_classes="panel")
 
 
93
  with gr.Row(elem_classes="panel"):
94
  video_input = gr.Video(label="Upload Video", format="mp4")
95
- with gr.Tabs(elem_classes="panel"):
96
- with gr.Tab("Upload Mask"):
97
- mask_upload = gr.Image(label="Upload Mask Image", type="numpy")
98
- with gr.Tab("Draw Mask"):
99
- draw_info = gr.Markdown("*Draw mask on the frame:* Use brush to highlight the region of interest.")
100
- mask_draw = gr.ImageEditor(label="Draw Mask", type="pil")
101
- # Buttons
102
  with gr.Row(elem_classes="panel"):
103
  generate_btn = gr.Button("πŸ”₯ Generate Heatmap", variant="primary")
104
  reset_btn = gr.Button("Reset")
105
- download_btn = gr.File(label="Download Video")
106
- # Status and output
107
  with gr.Row(elem_classes="panel"):
108
- status_text = gr.Markdown("") # to show status or final message
 
109
  with gr.Row(elem_classes="panel"):
110
  output_video = gr.Video(label="Output Video")
111
- # Event handlers
112
- # When video is uploaded, extract a frame and set it in the draw component
113
- def prep_frame_for_drawing(video_file):
114
- if video_file is None:
115
- return None
116
- frame = get_first_frame(video_file)
117
- return {'background': frame, 'composite': frame} # initial EditorValue
118
- video_input.change(fn=prep_frame_for_drawing, inputs=video_input, outputs=mask_draw)
119
- # Generate button triggers processing
120
- generate_btn.click(fn=process_video, inputs=[video_input, mask_upload, mask_draw], outputs=[status_text, output_video])
121
- # After video is generated, enable download (bind the file path from output)
122
- # (Gradio may automatically handle download if output_video has a file source)
123
- generate_btn.click(fn=lambda vid: vid, inputs=output_video, outputs=download_btn)
124
- # Reset button clears all
125
- reset_btn.click(fn=lambda: (None, None, None, "", None), inputs=[],
126
- outputs=[video_input, mask_upload, mask_draw, status_text, output_video])
127
- # Launch (if running locally; on HF Spaces this is handled automatically)
 
128
  if __name__ == "__main__":
129
  demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
  from PIL import Image
4
+ import cv2
5
+ import os
6
+ import uuid
7
+
8
  from extract_frames import video_to_keyframes
9
  from apply_mask import apply_mask_and_crop
10
  from run_gmm import run_gmm_inference
11
  from compose_video import compose_final_video
 
 
12
 
13
+
14
+ # Ensure folders exist
15
+ for path in [
16
+ "video_outputs/extracted_frames",
17
+ "video_outputs/masked_frames",
18
+ "video_outputs/output_heatmap",
19
+ "video_inputs",
20
+ "assets"
21
+ ]:
22
+ os.makedirs(path, exist_ok=True)
23
+
24
+
25
+ # Get first frame for preview
26
  def get_first_frame(video_path):
27
  cap = cv2.VideoCapture(video_path)
28
  success, frame = cap.read()
29
  cap.release()
30
  if success:
 
31
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
32
  return Image.fromarray(frame)
33
+ return None
34
+
35
+
36
+ def process_video(video_file, progress=gr.Progress()):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  base_dir = "video_outputs"
38
  extracted_dir = os.path.join(base_dir, "extracted_frames")
39
  masked_dir = os.path.join(base_dir, "masked_frames")
40
  heatmap_dir = os.path.join(base_dir, "output_heatmap")
41
+
42
+ # Clear old frames
43
+ for folder in [extracted_dir, masked_dir, heatmap_dir]:
44
+ for f in os.listdir(folder):
45
+ os.remove(os.path.join(folder, f))
46
+
47
+ # Load default mask
48
+ mask_path = "assets/default_mask.png"
49
+ if not os.path.exists(mask_path):
50
+ raise gr.Error("❌ Default mask not found at 'assets/default_mask.png'")
51
 
52
+ mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
53
+ if mask is None:
54
+ raise gr.Error("❌ Failed to load default mask.")
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  progress(0, desc="Extracting keyframes...")
57
  video_to_keyframes(video_file, extracted_dir)
58
 
59
+ # Load first frame to align mask size
60
+ first_frame_name = sorted(os.listdir(extracted_dir))[0]
61
+ first_frame = cv2.imread(os.path.join(extracted_dir, first_frame_name))
62
+
63
+ if first_frame is None:
64
+ raise gr.Error("❌ Failed to read first extracted keyframe.")
65
+
66
+ if mask.shape != first_frame.shape[:2]:
67
+ mask = cv2.resize(mask, (first_frame.shape[1], first_frame.shape[0]))
68
+
69
+ # Optional: get bounding box (coords) of table region
70
+ coords = cv2.findNonZero(mask)
71
+ if coords is None:
72
+ raise gr.Error("❌ No table region detected in default mask.")
73
+ x, y, w, h = cv2.boundingRect(coords)
74
+
75
  progress(0.3, desc="Applying mask and cropping...")
76
  apply_mask_and_crop(extracted_dir, mask, masked_dir)
77
 
 
83
  result_path = os.path.join(base_dir, video_name)
84
  compose_final_video(mask, heatmap_dir, extracted_dir, result_path)
85
 
86
+ progress(1.0, desc="Done βœ…")
87
+ return "βœ… Heatmap video generated successfully!", result_path, result_path
88
 
 
89
 
90
+ # Layout
 
91
  custom_css = """
92
+ .gradio-container {
93
+ background: url('/gradio_api/file=background.jpg') center/cover no-repeat !important;
94
+ background-color: #000 !important;
95
+ }
96
+ .panel {
97
+ max-width: 800px;
98
+ margin: 2rem auto;
99
+ padding: 2rem;
100
+ background: rgba(30,30,30, 0.8);
101
+ border-radius: 8px;
102
+ }
103
  """
104
+
105
+ with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css, title="UV Scan - Table Heatmap") as demo:
106
+ gr.Markdown("## πŸŽ₯ UV Scan – Table Heatmap Generator", elem_classes="panel")
107
+
108
  with gr.Row(elem_classes="panel"):
109
  video_input = gr.Video(label="Upload Video", format="mp4")
110
+
 
 
 
 
 
 
111
  with gr.Row(elem_classes="panel"):
112
  generate_btn = gr.Button("πŸ”₯ Generate Heatmap", variant="primary")
113
  reset_btn = gr.Button("Reset")
114
+ download_btn = gr.File(label="⬇️ Download Video")
115
+
116
  with gr.Row(elem_classes="panel"):
117
+ status_text = gr.Markdown("")
118
+
119
  with gr.Row(elem_classes="panel"):
120
  output_video = gr.Video(label="Output Video")
121
+
122
+ def on_video_upload(video_file):
123
+ return get_first_frame(video_file)
124
+
125
+ video_input.change(fn=on_video_upload, inputs=video_input, outputs=None)
126
+
127
+ generate_btn.click(
128
+ fn=process_video,
129
+ inputs=[video_input],
130
+ outputs=[status_text, output_video, download_btn]
131
+ )
132
+
133
+ reset_btn.click(
134
+ fn=lambda: (None, "", None, None),
135
+ inputs=[],
136
+ outputs=[video_input, status_text, output_video, download_btn]
137
+ )
138
+
139
  if __name__ == "__main__":
140
  demo.launch()