[Admin maintenance] Migrate grant to ZeroGPU

#1
Opened by multimodalart (HF Staff)
Files changed (2)
  1. app.py +72 -156
  2. requirements.txt +7 -6
app.py CHANGED
@@ -1,111 +1,105 @@
 
1
  import gradio as gr
2
- import random
3
  import torch
4
  import numpy as np
5
- from PIL import Image, ImageOps
6
  import os
7
  import json
8
- import sys
9
- import multiprocessing
10
- from concurrent.futures import ProcessPoolExecutor
11
- import time
12
 
13
- # Assume MagicQuill and other dependencies are present as per user instruction
14
  from MagicQuill import folder_paths
15
  from MagicQuill.llava_new import LLaVAModel
16
  from huggingface_hub import snapshot_download
17
 
18
- # Imports for SAM (Only needed in worker process, but imported here for checking)
19
  from segment_anything import sam_model_registry, SamPredictor
20
 
21
- # Download models (Main process does this once)
22
  hf_token = os.environ.get("HF_TOKEN")
23
  snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
24
  snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models_v2", token=hf_token)
25
 
26
- # --- Global Models for Main Process ---
27
- print("Initializing LLaVAModel (Main Process)...")
28
- # LLaVA is stateless/thread-safe enough or too big to duplicate, so we keep it in main process (or use threads)
29
  llavaModel = LLaVAModel()
30
  print("LLaVAModel initialized.")
31
 
32
- # --- Worker Process Logic for SAM ---
33
- # Global variable for the worker process to hold its own SAM instance
34
- worker_sam = None
35
-
36
- def init_worker_sam(device='cuda'):
37
- """
38
- This function is called when a new worker process starts.
39
- It initializes a standalone SAM model for that process.
40
- """
41
- global worker_sam
42
- print(f"Process {os.getpid()}: Initializing SAM model...")
43
-
44
- # Define SAM class locally or import it. Since it was defined in the script,
45
- # we can redefine a helper or import the logic.
46
- # Ideally, the SAM logic should be in a separate module to be picklable easily.
47
- # But for this script, we can define the loading logic here.
48
-
49
- checkpoint_path = 'models_v2/sam/sam_vit_b_01ec64.pth'
50
-
51
- # Load Model
52
- try:
53
- sam = sam_model_registry['vit_b'](checkpoint=checkpoint_path)
54
- sam.to(device=device)
55
- predictor = SamPredictor(sam)
56
-
57
- worker_sam = {
58
- "predictor": predictor
59
- }
60
- print(f"Process {os.getpid()}: SAM initialized.")
61
- except Exception as e:
62
- print(f"Process {os.getpid()}: Failed to init SAM: {e}")
63
-
64
- def run_sam_inference(image_np, coordinates_positive, coordinates_negative, bboxes):
65
- """
66
- The actual inference function running inside the worker process.
67
- """
68
- global worker_sam
69
-
70
- if worker_sam is None:
71
- # Fallback if init didn't run or failed (though ProcessPool initializer should handle it)
72
- init_worker_sam()
73
-
74
- predictor = worker_sam["predictor"]
75
-
76
- # Set Image
77
- predictor.set_image(image_np)
78
-
 
 
 
 
 
79
  input_point = []
80
  input_label = []
81
-
82
- # Process points
83
  if coordinates_positive:
84
  coords = json.loads(coordinates_positive) if isinstance(coordinates_positive, str) else coordinates_positive
85
  for p in coords:
86
  input_point.append([p['x'], p['y']])
87
  input_label.append(1)
88
-
89
  if coordinates_negative:
90
  coords = json.loads(coordinates_negative) if isinstance(coordinates_negative, str) else coordinates_negative
91
  for p in coords:
92
  input_point.append([p['x'], p['y']])
93
  input_label.append(0)
94
 
95
- # Process bbox
96
  input_box = None
97
  if bboxes:
98
  if isinstance(bboxes, str):
99
  try:
100
  bboxes = json.loads(bboxes)
101
- except:
102
  pass
103
-
104
  box_list = []
105
  if isinstance(bboxes, list):
106
  for box in bboxes:
107
  box_list.append(list(box))
108
-
109
  if len(box_list) > 0:
110
  input_box = np.array(box_list)
111
 
@@ -116,104 +110,35 @@ def run_sam_inference(image_np, coordinates_positive, coordinates_negative, bbox
116
  input_point = None
117
  input_label = None
118
 
119
- # Predict
120
- masks, scores, logits = predictor.predict(
121
  point_coords=input_point,
122
  point_labels=input_label,
123
  box=input_box,
124
  multimask_output=False,
125
  )
126
-
127
  mask_np = masks[0]
128
-
129
- # Post-processing
130
- # Simply convert mask to uint8 [0, 255] for transport
131
  if mask_np.dtype == bool:
132
  mask_np = mask_np.astype(np.uint8) * 255
133
  else:
134
  mask_np = (mask_np > 0).astype(np.uint8) * 255
135
-
136
- # Return mask as image for client to use
137
- # We return mask_np twice to satisfy the function signature or unpacker in segment()
138
- # segment() expects (image_with_alpha_np, mask_np)
139
- return mask_np, mask_np
140
 
 
141
 
142
- # --- Main Process Helpers ---
143
-
144
- # We need a pool. Since we are in a script, we initialize it in main block.
145
- sam_pool = None
146
-
147
- def numpy_to_tensor(numpy_array):
148
- tensor = torch.from_numpy(numpy_array).float().unsqueeze(0) / 255.
149
- return tensor
150
-
151
- def guess(original_image, add_color_image, add_edge_mask):
152
- # LLaVA inference runs in the main process (threaded)
153
- original_image_tensor = numpy_to_tensor(original_image)
154
- add_color_image_tensor = numpy_to_tensor(add_color_image)
155
- add_edge_mask_tensor = numpy_to_tensor(add_edge_mask)
156
-
157
- description, ans1, ans2 = llavaModel.process(original_image_tensor, add_color_image_tensor, add_edge_mask_tensor)
158
-
159
- ans_list = []
160
- if ans1 and ans1 != "":
161
- ans_list.append(ans1)
162
- if ans2 and ans2 != "":
163
- ans_list.append(ans2)
164
-
165
- return ", ".join(ans_list)
166
-
167
- def get_mask_bbox(mask_np):
168
- # mask_np: [1, H, W] or [H, W]
169
- if mask_np.ndim == 3:
170
- mask_np = mask_np[0]
171
-
172
- rows = np.any(mask_np, axis=1)
173
- cols = np.any(mask_np, axis=0)
174
- if not np.any(rows) or not np.any(cols):
175
- return None
176
-
177
- y_min, y_max = np.where(rows)[0][[0, -1]]
178
- x_min, x_max = np.where(cols)[0][[0, -1]]
179
- return int(x_min), int(y_min), int(x_max), int(y_max)
180
-
181
- def segment(image, coordinates_positive, coordinates_negative, bboxes):
182
- # image: numpy array (uint8)
183
- # Submit task to process pool
184
-
185
- print("image.shape:", image.shape)
186
- print("coordinates_positive:", coordinates_positive)
187
- print("coordinates_negative:", coordinates_negative)
188
- print("bboxes:", bboxes)
189
-
190
- if sam_pool is None:
191
- return None, json.dumps({'error': 'SAM pool not initialized'})
192
-
193
- # Future result
194
- future = sam_pool.submit(run_sam_inference, image, coordinates_positive, coordinates_negative, bboxes)
195
-
196
- # Wait for result
197
- image_with_alpha_np, mask_np = future.result(timeout=60) # 60s timeout
198
-
199
- # Convert back to PIL for Gradio
200
- res_pil = Image.fromarray(image_with_alpha_np)
201
-
202
- # Calculate bbox
203
  mask_bbox = get_mask_bbox(mask_np)
204
  if mask_bbox:
205
  x_min, y_min, x_max, y_max = mask_bbox
206
  seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
207
  else:
208
  seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
209
-
210
  return res_pil, json.dumps(seg_bbox)
211
 
212
- # --- Gradio UI ---
213
  with gr.Blocks() as app:
214
  with gr.Row():
215
  gr.Markdown("## MagicQuill Worker Server (Draw&Guess + SAM)")
216
-
217
  with gr.Tab("Draw & Guess"):
218
  with gr.Row():
219
  dg_input_img = gr.Image(label="Original Image")
@@ -221,7 +146,7 @@ with gr.Blocks() as app:
221
  dg_edge_img = gr.Image(image_mode="L", label="Edge Mask")
222
  dg_output = gr.Textbox(label="Prediction Output")
223
  dg_btn = gr.Button("Guess")
224
-
225
  dg_btn.click(
226
  fn=guess,
227
  inputs=[dg_input_img, dg_color_img, dg_edge_img],
@@ -229,20 +154,20 @@ with gr.Blocks() as app:
229
  api_name="guess_prompt",
230
  concurrency_limit=1
231
  )
232
-
233
  with gr.Tab("SAM Segmentation"):
234
  with gr.Row():
235
  sam_input_img = gr.Image(label="Input Image", type="numpy")
236
  sam_pos_coords = gr.Textbox(label="Pos Coords JSON")
237
  sam_neg_coords = gr.Textbox(label="Neg Coords JSON")
238
  sam_bboxes = gr.Textbox(label="BBoxes JSON")
239
-
240
  with gr.Row():
241
  sam_output_img = gr.Image(label="Segmented Image", format="png")
242
  sam_output_bbox = gr.Textbox(label="Mask BBox JSON")
243
-
244
  sam_btn = gr.Button("Segment")
245
-
246
  sam_btn.click(
247
  fn=segment,
248
  inputs=[sam_input_img, sam_pos_coords, sam_neg_coords, sam_bboxes],
@@ -251,15 +176,6 @@ with gr.Blocks() as app:
251
  concurrency_limit=5
252
  )
253
 
 
254
  if __name__ == "__main__":
255
- # Set start method to spawn for CUDA compatibility
256
- multiprocessing.set_start_method('spawn', force=True)
257
-
258
- # Initialize SAM Pool
259
- # Adjust max_workers based on GPU memory (e.g., 2-4 workers for SAM-B)
260
- NUM_SAM_WORKERS = 5
261
- print(f"Starting {NUM_SAM_WORKERS} SAM worker processes...")
262
- sam_pool = ProcessPoolExecutor(max_workers=NUM_SAM_WORKERS, initializer=init_worker_sam)
263
-
264
- # Launch Gradio
265
  app.queue(max_size=40).launch(max_threads=5)
 
1
+ import spaces
2
  import gradio as gr
 
3
  import torch
4
  import numpy as np
5
+ from PIL import Image
6
  import os
7
  import json
 
 
 
 
8
 
 
9
  from MagicQuill import folder_paths
10
  from MagicQuill.llava_new import LLaVAModel
11
  from huggingface_hub import snapshot_download
12
 
 
13
  from segment_anything import sam_model_registry, SamPredictor
14
 
 
15
  hf_token = os.environ.get("HF_TOKEN")
16
  snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
17
  snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models_v2", token=hf_token)
18
 
19
+ print("Initializing LLaVAModel...")
 
 
20
  llavaModel = LLaVAModel()
21
  print("LLaVAModel initialized.")
22
 
23
+ print("Initializing SAM...")
24
+ sam = sam_model_registry['vit_b'](checkpoint='models_v2/sam/sam_vit_b_01ec64.pth')
25
+ sam.to(device='cuda')
26
+ sam_predictor = SamPredictor(sam)
27
+ print("SAM initialized.")
28
+
29
+
30
+ def numpy_to_tensor(numpy_array):
31
+ tensor = torch.from_numpy(numpy_array).float().unsqueeze(0) / 255.
32
+ return tensor
33
+
34
+
35
+ @spaces.GPU
36
+ def guess(original_image, add_color_image, add_edge_mask):
37
+ original_image_tensor = numpy_to_tensor(original_image)
38
+ add_color_image_tensor = numpy_to_tensor(add_color_image)
39
+ add_edge_mask_tensor = numpy_to_tensor(add_edge_mask)
40
+
41
+ description, ans1, ans2 = llavaModel.process(original_image_tensor, add_color_image_tensor, add_edge_mask_tensor)
42
+
43
+ ans_list = []
44
+ if ans1 and ans1 != "":
45
+ ans_list.append(ans1)
46
+ if ans2 and ans2 != "":
47
+ ans_list.append(ans2)
48
+
49
+ return ", ".join(ans_list)
50
+
51
+
52
+ def get_mask_bbox(mask_np):
53
+ if mask_np.ndim == 3:
54
+ mask_np = mask_np[0]
55
+
56
+ rows = np.any(mask_np, axis=1)
57
+ cols = np.any(mask_np, axis=0)
58
+ if not np.any(rows) or not np.any(cols):
59
+ return None
60
+
61
+ y_min, y_max = np.where(rows)[0][[0, -1]]
62
+ x_min, x_max = np.where(cols)[0][[0, -1]]
63
+ return int(x_min), int(y_min), int(x_max), int(y_max)
64
+
65
+
66
+ @spaces.GPU
67
+ def segment(image, coordinates_positive, coordinates_negative, bboxes):
68
+ print("image.shape:", image.shape)
69
+ print("coordinates_positive:", coordinates_positive)
70
+ print("coordinates_negative:", coordinates_negative)
71
+ print("bboxes:", bboxes)
72
+
73
+ sam_predictor.set_image(image)
74
+
75
  input_point = []
76
  input_label = []
77
+
 
78
  if coordinates_positive:
79
  coords = json.loads(coordinates_positive) if isinstance(coordinates_positive, str) else coordinates_positive
80
  for p in coords:
81
  input_point.append([p['x'], p['y']])
82
  input_label.append(1)
83
+
84
  if coordinates_negative:
85
  coords = json.loads(coordinates_negative) if isinstance(coordinates_negative, str) else coordinates_negative
86
  for p in coords:
87
  input_point.append([p['x'], p['y']])
88
  input_label.append(0)
89
 
 
90
  input_box = None
91
  if bboxes:
92
  if isinstance(bboxes, str):
93
  try:
94
  bboxes = json.loads(bboxes)
95
+ except Exception:
96
  pass
97
+
98
  box_list = []
99
  if isinstance(bboxes, list):
100
  for box in bboxes:
101
  box_list.append(list(box))
102
+
103
  if len(box_list) > 0:
104
  input_box = np.array(box_list)
105
 
 
110
  input_point = None
111
  input_label = None
112
 
113
+ masks, scores, logits = sam_predictor.predict(
 
114
  point_coords=input_point,
115
  point_labels=input_label,
116
  box=input_box,
117
  multimask_output=False,
118
  )
119
+
120
  mask_np = masks[0]
 
 
 
121
  if mask_np.dtype == bool:
122
  mask_np = mask_np.astype(np.uint8) * 255
123
  else:
124
  mask_np = (mask_np > 0).astype(np.uint8) * 255
 
 
 
 
 
125
 
126
+ res_pil = Image.fromarray(mask_np)
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  mask_bbox = get_mask_bbox(mask_np)
129
  if mask_bbox:
130
  x_min, y_min, x_max, y_max = mask_bbox
131
  seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
132
  else:
133
  seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
134
+
135
  return res_pil, json.dumps(seg_bbox)
136
 
137
+
138
  with gr.Blocks() as app:
139
  with gr.Row():
140
  gr.Markdown("## MagicQuill Worker Server (Draw&Guess + SAM)")
141
+
142
  with gr.Tab("Draw & Guess"):
143
  with gr.Row():
144
  dg_input_img = gr.Image(label="Original Image")
 
146
  dg_edge_img = gr.Image(image_mode="L", label="Edge Mask")
147
  dg_output = gr.Textbox(label="Prediction Output")
148
  dg_btn = gr.Button("Guess")
149
+
150
  dg_btn.click(
151
  fn=guess,
152
  inputs=[dg_input_img, dg_color_img, dg_edge_img],
 
154
  api_name="guess_prompt",
155
  concurrency_limit=1
156
  )
157
+
158
  with gr.Tab("SAM Segmentation"):
159
  with gr.Row():
160
  sam_input_img = gr.Image(label="Input Image", type="numpy")
161
  sam_pos_coords = gr.Textbox(label="Pos Coords JSON")
162
  sam_neg_coords = gr.Textbox(label="Neg Coords JSON")
163
  sam_bboxes = gr.Textbox(label="BBoxes JSON")
164
+
165
  with gr.Row():
166
  sam_output_img = gr.Image(label="Segmented Image", format="png")
167
  sam_output_bbox = gr.Textbox(label="Mask BBox JSON")
168
+
169
  sam_btn = gr.Button("Segment")
170
+
171
  sam_btn.click(
172
  fn=segment,
173
  inputs=[sam_input_img, sam_pos_coords, sam_neg_coords, sam_bboxes],
 
176
  concurrency_limit=5
177
  )
178
 
179
+
180
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
181
  app.queue(max_size=40).launch(max_threads=5)
requirements.txt CHANGED
@@ -14,7 +14,7 @@ anyio==4.4.0
14
  async-timeout==4.0.3
15
  attrs==23.2.0
16
  beautifulsoup4==4.12.3
17
- bitsandbytes==0.43.3
18
  certifi==2024.7.4
19
  cffi==1.16.0
20
  chardet==5.2.0
@@ -33,7 +33,6 @@ einops-exts==0.0.4
33
  embreex==2.17.7.post5
34
  eval-type-backport==0.2.0
35
  exceptiongroup==1.2.2
36
- fastapi
37
  ffmpy==0.4.0
38
  filelock==3.15.4
39
  flatbuffers==24.3.25
@@ -132,11 +131,10 @@ sounddevice==0.4.7
132
  soupsieve==2.5
133
  spandrel==0.3.4
134
  stanza==1.1.1
135
- starlette
136
  svg-path==6.3
137
  svglib==1.5.1
138
  svgwrite==1.4.3
139
- sympy==1.13.1
140
  tabulate==0.9.0
141
  termcolor==2.4.0
142
  threadpoolctl==3.5.0
@@ -151,7 +149,6 @@ tqdm==4.66.5
151
  trampoline==0.1.2
152
  transformers==4.37.2
153
  trimesh==4.4.3
154
- triton==2.1.0
155
  torchsde==0.2.6
156
  typer==0.12.5
157
  typing-extensions==4.12.2
@@ -169,4 +166,8 @@ yacs==0.1.8
169
  yapf==0.40.2
170
  yarl==1.9.4
171
  zipp==3.19.2
172
- git+https://github.com/facebookresearch/segment-anything.git
 
 
 
 
 
14
  async-timeout==4.0.3
15
  attrs==23.2.0
16
  beautifulsoup4==4.12.3
17
+ bitsandbytes
18
  certifi==2024.7.4
19
  cffi==1.16.0
20
  chardet==5.2.0
 
33
  embreex==2.17.7.post5
34
  eval-type-backport==0.2.0
35
  exceptiongroup==1.2.2
 
36
  ffmpy==0.4.0
37
  filelock==3.15.4
38
  flatbuffers==24.3.25
 
131
  soupsieve==2.5
132
  spandrel==0.3.4
133
  stanza==1.1.1
 
134
  svg-path==6.3
135
  svglib==1.5.1
136
  svgwrite==1.4.3
137
+ sympy==1.13.3
138
  tabulate==0.9.0
139
  termcolor==2.4.0
140
  threadpoolctl==3.5.0
 
149
  trampoline==0.1.2
150
  transformers==4.37.2
151
  trimesh==4.4.3
 
152
  torchsde==0.2.6
153
  typer==0.12.5
154
  typing-extensions==4.12.2
 
166
  yapf==0.40.2
167
  yarl==1.9.4
168
  zipp==3.19.2
169
+ git+https://github.com/facebookresearch/segment-anything.git
170
+ starlette<0.38
171
+ fastapi<0.112
172
+ torch==2.8.0
173
+ torchvision==0.23.0