John Ho committed on
Commit
33a6de4
·
1 Parent(s): a31ba73

testing gh action to push to HF Space

Browse files
Files changed (4) hide show
  1. .github/workflows/deploy_to_hf_space.yaml +28 -15
  2. README.md +3 -0
  3. app.py +160 -166
  4. requirements.txt +9 -0
.github/workflows/deploy_to_hf_space.yaml CHANGED
@@ -30,7 +30,7 @@ jobs:
30
  - name: Set up Python
31
  uses: actions/setup-python@v5
32
  with:
33
- python-version: "3.12" # Recommended: specify a precise version like '3.10', '3.11', or '3.12'
34
 
35
  - name: Install uv
36
  # Installs the uv tool on the GitHub Actions runner
@@ -39,18 +39,21 @@ jobs:
39
  - name: Check for pyproject.toml existence
40
  id: check_pyproject
41
  run: |
42
- if [ -f pyproject.toml ]; then
43
- echo "::notice::pyproject.toml found. Proceeding with uv pip compile."
44
- echo "pyproject_exists=true" >> $GITHUB_OUTPUT
 
 
 
45
  else
46
- echo "::notice::pyproject.toml not found. Skipping requirements.txt generation via uv pip compile."
47
- echo "pyproject_exists=false" >> $GITHUB_OUTPUT
48
  fi
49
 
50
  - name: Generate requirements.txt using uv
51
  id: generate_reqs
52
  # This step will only run if pyproject.toml was found in the previous step
53
- if: ${{ steps.check_pyproject.outputs.pyproject_exists == 'true' }}
54
  run: |
55
  # Use uv pip compile to generate a locked requirements.txt from pyproject.toml
56
  # This ensures reproducibility.
@@ -66,15 +69,25 @@ jobs:
66
  exit 1
67
  fi
68
 
69
- # - name: Get ready to push to HuggingFace Space
70
- # # This step will only run if 'push_enabled' output from the previous step is 'true'
71
- # if: ${{ steps.check_hf_token.outputs.push_enabled == 'true' }}
72
- # uses: actions/checkout@v3
73
- # with:
74
- # fetch-depth: 0
75
- # lfs: true
 
 
76
  - name: Push to HuggingFace Space
77
  if: ${{ steps.check_hf_token.outputs.push_enabled == 'true' }}
78
  env:
79
  HF_TOKEN: ${{ secrets.HF_TOKEN }}
80
- run: git push https://HF_USERNAME:$HF_TOKEN@huggingface.co/spaces/HF_USERNAME/SPACE_NAME main
 
 
 
 
 
 
 
 
 
30
  - name: Set up Python
31
  uses: actions/setup-python@v5
32
  with:
33
+ python-version: "3.10" # Recommended: specify a precise version like '3.10', '3.11', or '3.12'
34
 
35
  - name: Install uv
36
  # Installs the uv tool on the GitHub Actions runner
 
39
  - name: Check for pyproject.toml existence
40
  id: check_pyproject
41
  run: |
42
+ if [ -f requirements.txt ]; then
43
+ echo "::notice::requirements.txt already exists. Skipping uv generation."
44
+ echo "generate_reqs=false" >> $GITHUB_OUTPUT
45
+ elif [ -f pyproject.toml ]; then
46
+ echo "::notice::pyproject.toml found and no requirements.txt. Proceeding with uv pip compile."
47
+ echo "generate_reqs=true" >> $GITHUB_OUTPUT
48
  else
49
+ echo "::notice::Neither requirements.txt nor pyproject.toml found. Skipping uv pip compile."
50
+ echo "generate_reqs=false" >> $GITHUB_OUTPUT
51
  fi
52
 
53
  - name: Generate requirements.txt using uv
54
  id: generate_reqs
55
  # This step will only run if pyproject.toml was found in the previous step
56
+ if: ${{ steps.check_pyproject.outputs.generate_reqs == 'true' }}
57
  run: |
58
  # Use uv pip compile to generate a locked requirements.txt from pyproject.toml
59
  # This ensures reproducibility.
 
69
  exit 1
70
  fi
71
 
72
- name: Commit requirements.txt if changed
  if: ${{ steps.check_pyproject.outputs.generate_reqs == 'true' }}
  run: |
    git config user.name "github-actions[bot]"
    git config user.email "github-actions[bot]@users.noreply.github.com"
    git add requirements.txt
    # A bare `git commit` exits non-zero when nothing is staged, which would
    # fail the whole job. Only commit when the staged file actually changed.
    if git diff --cached --quiet; then
      echo "::notice::requirements.txt unchanged; nothing to commit."
    else
      git commit -m "chore: update requirements.txt [auto-generated by CI]"
      echo "requirements.txt committed."
    fi
80
+
81
- name: Push to HuggingFace Space
  if: ${{ steps.check_hf_token.outputs.push_enabled == 'true' }}
  env:
    HF_TOKEN: ${{ secrets.HF_TOKEN }}
    FORCE_PUSH: ${{ secrets.FORCE_PUSH }}
  run: |
    # Normal push by default; force-push only when the FORCE_PUSH secret is set
    # (useful when the Space's history has diverged from this repo).
    if [ -z "$FORCE_PUSH" ]; then
      echo "::notice::FORCE_PUSH secret is not set."
      git push https://GF-John:$HF_TOKEN@huggingface.co/spaces/GF-John/sam3 main
    else
      echo "::notice::FORCE_PUSH secret is set. Doing Force Push to Hugging Face Space."
      git push -f https://GF-John:$HF_TOKEN@huggingface.co/spaces/GF-John/sam3 main
    fi
README.md CHANGED
@@ -10,6 +10,9 @@ pinned: false
10
  short_description: short description for your Space App
11
  ---
12
 
 
 
 
13
  # The HuggingFace Space Template
14
  setup with [github action to automatically update your space](https://huggingface.co/docs/hub/spaces-github-actions)
15
  and manage dependencies with `uv`
 
10
  short_description: short description for your Space App
11
  ---
12
 
13
+ # SAM3 HuggingFace Space Demo
14
+ with inspiration from [prithivMLmods' demo](https://huggingface.co/spaces/prithivMLmods/SAM3-Demo), using the [transformers API](https://huggingface.co/docs/transformers/main/en/model_doc/sam3_video)
15
+
16
  # The HuggingFace Space Template
17
  setup with [github action to automatically update your space](https://huggingface.co/docs/hub/spaces-github-actions)
18
  and manage dependencies with `uv`
app.py CHANGED
@@ -1,18 +1,22 @@
1
- import spaces, torch, time
 
 
 
2
  import gradio as gr
 
 
 
 
 
3
  from transformers import (
4
- AutoModelForImageTextToText,
5
- AutoProcessor,
6
- BitsAndBytesConfig,
7
  )
8
 
9
- # Flash Attention for ZeroGPU
10
- import subprocess
11
-
12
- subprocess.run(
13
- "pip install flash-attn --no-build-isolation",
14
- env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
15
- shell=True,
16
  )
17
 
18
  # Set target DEVICE and DTYPE
@@ -22,153 +26,166 @@ DTYPE = (
22
  else torch.float16
23
  )
24
  DEVICE = "auto"
25
- print(f"Device: {DEVICE}, dtype: {DTYPE}")
26
-
27
-
28
- def load_model(
29
- model_name: str = "chancharikm/qwen2.5-vl-7b-cam-motion-preview",
30
- use_flash_attention: bool = True,
31
- apply_quantization: bool = True,
32
- ):
33
- bnb_config = BitsAndBytesConfig(
34
- load_in_4bit=True, # Load model weights in 4-bit
35
- bnb_4bit_quant_type="nf4", # Use NF4 quantization (or "fp4")
36
- bnb_4bit_compute_dtype=DTYPE, # Perform computations in bfloat16/float16
37
- bnb_4bit_use_double_quant=True, # Optional: further quantization for slightly more memory saving
38
- )
39
-
40
- # Determine model family from model name
41
- model_family = model_name.split("/")[-1].split("-")[0]
42
-
43
- # Common model loading arguments
44
- common_args = {
45
- "torch_dtype": DTYPE,
46
- "device_map": DEVICE,
47
- "low_cpu_mem_usage": True,
48
- "quantization_config": bnb_config if apply_quantization else None,
49
- }
50
- if use_flash_attention:
51
- common_args["attn_implementation"] = "flash_attention_2"
52
-
53
- # Load model based on family
54
- match model_family:
55
- # case "qwen2.5" | "Qwen2.5":
56
- # model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
57
- # model_name, **common_args
58
- # )
59
- case "InternVL3":
60
- model = AutoModelForImageTextToText.from_pretrained(
61
- model_name, **common_args
 
 
 
 
 
 
62
  )
63
- case _:
64
- raise ValueError(f"Unsupported model family: {model_family}")
 
 
 
 
 
 
 
 
 
 
65
 
66
- # Set model to evaluation mode for inference (disables dropout, etc.)
67
- return model.eval()
68
 
 
 
 
 
69
 
70
- def load_processor(model_name="Qwen/Qwen2.5-VL-7B-Instruct"):
71
- return AutoProcessor.from_pretrained(
72
- model_name,
73
- device_map=DEVICE,
74
- use_fast=True,
75
- torch_dtype=DTYPE,
76
- )
77
 
78
 
79
  print("Loading Models and Processors...")
80
- MODEL_ZOO = {
81
- "qwen2.5-vl-7b-instruct": load_model(
82
- model_name="Qwen/Qwen2.5-VL-7B-Instruct",
83
- use_flash_attention=False,
84
- apply_quantization=False,
85
- ),
86
- "InternVL3-1B-hf": load_model(
87
- model_name="OpenGVLab/InternVL3-1B-hf",
88
- use_flash_attention=False,
89
- apply_quantization=False,
90
- ),
91
- "InternVL3-2B-hf": load_model(
92
- model_name="OpenGVLab/InternVL3-2B-hf",
93
- use_flash_attention=False,
94
- apply_quantization=False,
95
- ),
96
- "InternVL3-8B-hf": load_model(
97
- model_name="OpenGVLab/InternVL3-8B-hf",
98
- use_flash_attention=False,
99
- apply_quantization=True,
100
- ),
101
- }
102
-
103
- PROCESSORS = {
104
- "qwen2.5-vl-7b-instruct": load_processor("Qwen/Qwen2.5-VL-7B-Instruct"),
105
- "InternVL3-1B-hf": load_processor("OpenGVLab/InternVL3-1B-hf"),
106
- "InternVL3-2B-hf": load_processor("OpenGVLab/InternVL3-2B-hf"),
107
- "InternVL3-8B-hf": load_processor("OpenGVLab/InternVL3-8B-hf"),
108
- }
109
- print("Models and Processors Loaded!")
110
 
111
 
112
  # Our Inference Function
113
  @spaces.GPU(duration=120)
114
- def video_inference(
115
- video_path: str,
116
- prompt: str,
117
- model_name: str,
118
- fps: int = 8,
119
- max_tokens: int = 512,
120
- temperature: float = 0.1,
121
- ):
122
- s_time = time.time()
123
- model = MODEL_ZOO[model_name]
124
- processor = PROCESSORS[model_name]
125
- messages = [
126
- {
127
- "role": "user",
128
- "content": [
129
- {
130
- "type": "video",
131
- "video": video_path,
132
- },
133
- {"type": "text", "text": prompt},
134
- ],
135
  }
136
- ]
137
- with torch.no_grad():
138
- model_family = model_name.split("-")[0]
139
- match model_family:
140
- case "InternVL3":
141
- inputs = processor.apply_chat_template(
142
- messages,
143
- add_generation_prompt=True,
144
- tokenize=True,
145
- return_dict=True,
146
- return_tensors="pt",
147
- fps=fps,
148
- # num_frames = 8
149
- ).to("cuda", dtype=DTYPE)
150
-
151
- output = model.generate(
152
- **inputs,
153
- max_new_tokens=max_tokens,
154
- temperature=float(temperature),
155
- do_sample=temperature > 0.0,
156
- )
157
- output_text = processor.decode(
158
- output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  )
160
- case _:
161
- raise ValueError(f"{model_name} is not currently supported")
162
- return {
163
- "output_text": output_text,
164
- "fps": fps,
165
- "inference_time": time.time() - s_time,
166
- }
 
 
 
 
 
 
167
 
168
 
169
  # the Gradio App
170
  app = gr.Interface(
171
- fn=inference,
172
  inputs=[
173
  gr.Video(label="Input Video"),
174
  gr.Textbox(
@@ -177,33 +194,10 @@ app = gr.Interface(
177
  info="Some models like [cam motion](https://huggingface.co/chancharikm/qwen2.5-vl-7b-cam-motion-preview) are trained specific prompts",
178
  value="Describe the camera motion in this video.",
179
  ),
180
- gr.Dropdown(label="Model", choices=list(MODEL_ZOO.keys())),
181
- gr.Number(
182
- label="FPS",
183
- info="inference sampling rate (Qwen2.5VL is trained on videos with 8 fps); a value of 0 means the FPS of the input video will be used",
184
- value=8,
185
- minimum=0,
186
- step=1,
187
- ),
188
- gr.Slider(
189
- label="Max Tokens",
190
- info="maximum number of tokens to generate",
191
- value=128,
192
- minimum=32,
193
- maximum=512,
194
- step=32,
195
- ),
196
- gr.Slider(
197
- label="Temperature",
198
- value=0.0,
199
- minimum=0.0,
200
- maximum=1.0,
201
- step=0.1,
202
- ),
203
  ],
204
  outputs=gr.JSON(label="Output JSON"),
205
- title="Video Chat with VLM",
206
- description='comparing various "small" VLMs on the task of video captioning',
207
  api_name="video_inference",
208
  )
209
  app.launch(
 
1
import sys
import tempfile
import time

import cv2
import gradio as gr
import matplotlib
import numpy as np
import spaces
import torch
from loguru import logger
from PIL import Image
from transformers import (
    Sam3VideoModel,
    Sam3VideoProcessor,
)

# Configure loguru: drop the default sink and install a custom stderr format.
# NOTE: `sys` and `logger` were used below without being imported, which would
# raise NameError at module import time — both imports are added above.
logger.remove()
logger.add(
    sys.stderr,
    format="<d>{time:YYYY-MM-DD ddd HH:mm:ss}</d> | <lvl>{level}</lvl> | <lvl>{message}</lvl>",
)
21
 
22
  # Set target DEVICE and DTYPE
 
26
  else torch.float16
27
  )
28
  DEVICE = "auto"
29
+ logger.info(f"Device: {DEVICE}, dtype: {DTYPE}")
30
+
31
+
32
def apply_mask_overlay(base_image, mask_data, object_ids=None, opacity=0.5):
    """Overlay segmentation masks on an image and return an RGB PIL image.

    When `object_ids` is supplied (one ID per mask), each tracked object keeps
    a stable color derived from its rank among the unique IDs; otherwise colors
    are assigned by mask index.
    """
    if isinstance(base_image, np.ndarray):
        base_image = Image.fromarray(base_image)
    base_image = base_image.convert("RGBA")

    # Nothing to draw — return the (converted) original unchanged.
    if mask_data is None or len(mask_data) == 0:
        return base_image.convert("RGB")

    if isinstance(mask_data, torch.Tensor):
        mask_data = mask_data.cpu().numpy()
    mask_data = mask_data.astype(np.uint8)

    # Collapse leading singleton batch/channel dimensions.
    if mask_data.ndim == 4:
        mask_data = mask_data[0]
    if mask_data.ndim == 3 and mask_data.shape[0] == 1:
        mask_data = mask_data[0]

    num_masks = mask_data.shape[0] if mask_data.ndim == 3 else 1
    if mask_data.ndim == 2:
        mask_data = [mask_data]
        num_masks = 1

    # Build one RGB color tuple per mask.
    if object_ids is not None and len(object_ids) == num_masks:
        # matplotlib >= 3.6 exposes a colormap registry; older versions fall
        # back to the cm.get_cmap API.
        try:
            cmap = matplotlib.colormaps["rainbow"]
        except AttributeError:
            import matplotlib.cm as cm

            cmap = cm.get_cmap("rainbow")
        ranked = {oid: pos for pos, oid in enumerate(sorted(set(object_ids)))}
        denom = max(len(ranked), 1)
        palette = [
            tuple(int(channel * 255) for channel in cmap(ranked[oid] / denom)[:3])
            for oid in object_ids
        ]
    else:
        try:
            cmap = matplotlib.colormaps["rainbow"].resampled(max(num_masks, 1))
        except AttributeError:
            import matplotlib.cm as cm

            cmap = cm.get_cmap("rainbow").resampled(max(num_masks, 1))
        palette = [
            tuple(int(channel * 255) for channel in cmap(idx)[:3])
            for idx in range(num_masks)
        ]

    overlay = Image.new("RGBA", base_image.size, (0, 0, 0, 0))

    for idx, raw_mask in enumerate(mask_data):
        bitmap = Image.fromarray((raw_mask * 255).astype(np.uint8))
        if bitmap.size != base_image.size:
            bitmap = bitmap.resize(base_image.size, resample=Image.NEAREST)

        # Tint layer: fully transparent color, alpha taken from the mask
        # scaled by the requested opacity (zero pixels stay transparent).
        layer = Image.new("RGBA", base_image.size, palette[idx] + (0,))
        layer.putalpha(bitmap.point(lambda v: int(v * opacity) if v > 0 else 0))
        overlay = Image.alpha_composite(overlay, layer)

    return Image.alpha_composite(base_image, overlay).convert("RGB")
100
 
101
 
102
print("Loading Models and Processors...")
try:
    # BUG FIX: `DEVICE` is "auto", which `Module.to()` rejects — "auto" is
    # only understood by `device_map` in `from_pretrained`. The old
    # `.to(DEVICE, dtype=DTYPE)` always raised, was swallowed by the except
    # below, and left both models None (every request then failed).
    VID_MODEL = Sam3VideoModel.from_pretrained(
        "facebook/sam3", device_map=DEVICE, torch_dtype=DTYPE
    )
    VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("facebook/sam3")
    logger.success("Models and Processors Loaded!")
except Exception as e:
    # Keep the app importable even when model download/placement fails;
    # video_inference reports the failure to the user.
    logger.error(f"❌ CRITICAL ERROR LOADING VIDEO MODELS: {e}")
    VID_MODEL = None
    VID_PROCESSOR = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
 
113
  # Our Inference Function
114
@spaces.GPU(duration=120)
def video_inference(input_video, prompt):
    """
    Segments objects in a video using a text prompt (SAM3 video pipeline).

    Args:
        input_video: Path string or Gradio upload dict with a 'name' key.
        prompt: Text prompt describing the object(s) to segment.

    Returns:
        dict with 'output_video' (path to the rendered mp4, or None) and
        a human-readable 'status' message.
    """
    if VID_MODEL is None or VID_PROCESSOR is None:
        return {
            "output_video": None,
            "status": "Video Models failed to load on startup.",
        }
    if input_video is None or not prompt:
        return {"output_video": None, "status": "Missing video or prompt."}
    try:
        # Gradio passes a dict with 'name' key for uploaded files
        video_path = (
            input_video
            if isinstance(input_video, str)
            else input_video.get("name", None)
        )
        if not video_path:
            return {"output_video": None, "status": "Invalid video input."}

        # Decode every frame up front (RGB) — the SAM3 session needs the
        # full frame list, and we reuse it for rendering below.
        video_cap = cv2.VideoCapture(video_path)
        vid_fps = video_cap.get(cv2.CAP_PROP_FPS)
        vid_w = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        vid_h = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        video_frames = []
        while video_cap.isOpened():
            ret, frame = video_cap.read()
            if not ret:
                break
            video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        video_cap.release()
        if len(video_frames) == 0:
            return {"output_video": None, "status": "No frames found in video."}

        session = VID_PROCESSOR.init_video_session(
            video=video_frames, inference_device=DEVICE, dtype=DTYPE
        )
        session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=prompt)

        # FIX: tempfile.mktemp is deprecated and race-prone; create the file
        # atomically and keep it on disk (delete=False) for Gradio to serve.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            temp_out_path = tmp.name
        video_writer = cv2.VideoWriter(
            temp_out_path, cv2.VideoWriter_fourcc(*"mp4v"), vid_fps, (vid_w, vid_h)
        )
        try:
            for model_out in VID_MODEL.propagate_in_video_iterator(
                inference_session=session, max_frame_num_to_track=len(video_frames)
            ):
                post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
                f_idx = model_out.frame_idx
                original_pil = Image.fromarray(video_frames[f_idx])
                if "masks" in post_processed:
                    detected_masks = post_processed["masks"]
                    object_ids = post_processed["object_ids"]
                    if detected_masks.ndim == 4:
                        detected_masks = detected_masks.squeeze(1)
                    final_frame = apply_mask_overlay(
                        original_pil, detected_masks, object_ids=object_ids
                    )
                else:
                    # No detections on this frame — emit it unmodified.
                    final_frame = original_pil
                video_writer.write(
                    cv2.cvtColor(np.array(final_frame), cv2.COLOR_RGB2BGR)
                )
        finally:
            # FIX: always flush/close the writer, even if inference raises
            # mid-stream, so the temp file is never left half-open.
            video_writer.release()
        return {
            "output_video": temp_out_path,
            "status": "Video processing completed successfully.✅",
        }
    except Exception as e:
        # Surface the failure as JSON rather than crashing the Gradio app.
        return {
            "output_video": None,
            "status": f"Error during video processing: {str(e)}",
        }
184
 
185
 
186
  # the Gradio App
187
  app = gr.Interface(
188
+ fn=video_inference,
189
  inputs=[
190
  gr.Video(label="Input Video"),
191
  gr.Textbox(
 
194
  info="Some models like [cam motion](https://huggingface.co/chancharikm/qwen2.5-vl-7b-cam-motion-preview) are trained specific prompts",
195
  value="Describe the camera motion in this video.",
196
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  ],
198
  outputs=gr.JSON(label="Output JSON"),
199
+ title="SAM3 Video Segmentation",
200
+ description="Segment Objects in Video using Text Prompts",
201
  api_name="video_inference",
202
  )
203
  app.launch(
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transformers
4
+ accelerate
5
+ loguru
6
+ opencv-python-headless>=4.11.0.86
7
+ peft
8
+ sentencepiece
9
+ matplotlib