# Hugging Face Space file — last commit by apolinario (046eb30, 5.54 kB):
#   "Adapt to ZeroGPU: load models at module scope; @spaces.GPU on
#    guess/segment; drop ProcessPool + triton pin"
import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
import os
import json
from MagicQuill import folder_paths
from MagicQuill.llava_new import LLaVAModel
from huggingface_hub import snapshot_download
from segment_anything import sam_model_registry, SamPredictor
# Optional Hugging Face access token from the environment. It is only passed to
# the MagicQuillV2 download (NOTE(review): presumably that repo needs auth — confirm).
hf_token = os.environ.get("HF_TOKEN")
# Download model weights at import time so they are on disk before the models
# below are constructed (the SAM checkpoint path below points into models_v2).
snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models_v2", token=hf_token)
# LLaVA model used by guess(). NOTE(review): presumably reads its weights from the
# downloaded "models" directory — confirm in MagicQuill.llava_new.
print("Initializing LLaVAModel...")
llavaModel = LLaVAModel()
print("LLaVAModel initialized.")
# SAM ViT-B from the downloaded checkpoint, moved to CUDA at module scope.
# ZeroGPU pattern: models are loaded once at import, and the GPU is requested
# only inside the @spaces.GPU-decorated handlers that use them.
print("Initializing SAM...")
sam = sam_model_registry['vit_b'](checkpoint='models_v2/sam/sam_vit_b_01ec64.pth')
sam.to(device='cuda')
# Shared predictor used by segment(); set_image() is called on it per request.
sam_predictor = SamPredictor(sam)
print("SAM initialized.")
def numpy_to_tensor(numpy_array):
    """Wrap an image array as a float32 tensor with a new leading batch axis.

    Values are divided by 255, so uint8 input in 0..255 maps to 0..1.
    An array of shape ``S`` yields a tensor of shape ``(1, *S)``.
    """
    batched = torch.from_numpy(numpy_array).unsqueeze(0)
    return batched.float() / 255.
@spaces.GPU
def guess(original_image, add_color_image, add_edge_mask):
    """Ask the LLaVA model what the user's strokes depict.

    Args:
        original_image: Original image as a numpy array.
        add_color_image: Image containing the user's colour strokes, numpy array.
        add_edge_mask: Edge/stroke mask as a numpy array (the UI loads it
            with ``image_mode="L"``).

    Returns:
        The non-empty answers returned by ``llavaModel.process`` joined with
        ", ", or an empty string when neither answer is non-empty.

    Raises:
        gr.Error: if any of the three images is missing (``None``), which
            Gradio passes for an empty image component.
    """
    images = (original_image, add_color_image, add_edge_mask)
    if any(img is None for img in images):
        raise gr.Error("guess: original image, colored image and edge mask are all required.")
    original_image_tensor = numpy_to_tensor(original_image)
    add_color_image_tensor = numpy_to_tensor(add_color_image)
    add_edge_mask_tensor = numpy_to_tensor(add_edge_mask)
    # The first element (a description) is not used by this endpoint.
    _description, ans1, ans2 = llavaModel.process(
        original_image_tensor, add_color_image_tensor, add_edge_mask_tensor
    )
    # A truthy answer is already non-empty, so one truthiness test drops
    # both None and "".
    return ", ".join(ans for ans in (ans1, ans2) if ans)
def get_mask_bbox(mask_np):
    """Return the tight bounding box of the non-zero pixels of a mask.

    A 3-D input is reduced to its first plane (``mask_np[0]``) before the
    box is computed.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as Python ints (inclusive pixel
        coordinates), or ``None`` when the mask has no non-zero pixel.
    """
    plane = mask_np[0] if mask_np.ndim == 3 else mask_np
    ys = np.flatnonzero(np.any(plane, axis=1))
    xs = np.flatnonzero(np.any(plane, axis=0))
    # Rows and columns are empty together, but check both to mirror the contract.
    if ys.size == 0 or xs.size == 0:
        return None
    return int(xs[0]), int(ys[0]), int(xs[-1]), int(ys[-1])
def _load_json_arg(value):
    """Decode *value* from JSON when it is a string; otherwise return it unchanged."""
    return json.loads(value) if isinstance(value, str) else value


def _parse_points(coordinates, label):
    """Turn one point prompt into parallel ``(points, labels)`` lists.

    Args:
        coordinates: JSON string or already-decoded list of dicts with ``'x'``
            and ``'y'`` keys. Falsy values (``None``, ``""``) and JSON ``null``
            mean "no points".
        label: Label assigned to every point (segment uses 1 for positive and
            0 for negative clicks).

    Returns:
        ``([[x, y], ...], [label, ...])``; both lists are empty when there are
        no points.

    Raises:
        json.JSONDecodeError: if *coordinates* is a string that is not valid JSON.
    """
    if not coordinates:
        return [], []
    coords = _load_json_arg(coordinates) or []
    points = [[p['x'], p['y']] for p in coords]
    return points, [label] * len(points)


def _parse_boxes(bboxes):
    """Build the SAM box prompt array, or ``None`` when no usable boxes are given.

    *bboxes* may be a JSON string or an already-decoded list. Each box is
    converted with ``list(box)``. NOTE(review): presumably each box is a
    4-element ``[x1, y1, x2, y2]`` sequence, as ``SamPredictor`` expects XYXY;
    a dict box would become its keys — confirm the caller's format.
    Unparseable JSON is reported and ignored (best-effort).
    """
    if not bboxes:
        return None
    if isinstance(bboxes, str):
        try:
            bboxes = json.loads(bboxes)
        except json.JSONDecodeError as err:
            print(f"segment: ignoring unparseable bboxes {bboxes!r}: {err}")
            return None
    if not isinstance(bboxes, list):
        return None
    box_list = [list(box) for box in bboxes]
    return np.array(box_list) if box_list else None


def _mask_outputs(mask):
    """Convert a SAM mask into ``(PIL image with 0/255 values, bbox JSON string)``."""
    if mask.dtype == bool:
        mask_u8 = mask.astype(np.uint8) * 255
    else:
        mask_u8 = (mask > 0).astype(np.uint8) * 255
    bbox = get_mask_bbox(mask_u8)
    if bbox:
        x_min, y_min, x_max, y_max = bbox
        seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
    else:
        seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
    return Image.fromarray(mask_u8), json.dumps(seg_bbox)


@spaces.GPU
def segment(image, coordinates_positive, coordinates_negative, bboxes):
    """Segment *image* with SAM from point and/or box prompts.

    Args:
        image: Input image as a numpy array (the UI component uses
            ``type="numpy"``); passed straight to ``SamPredictor.set_image``.
        coordinates_positive: JSON string or list of ``{'x', 'y'}`` dicts for
            positive clicks (label 1).
        coordinates_negative: Same format, for negative clicks (label 0).
        bboxes: JSON string or list of boxes; invalid JSON is reported and
            ignored.

    Returns:
        ``(mask_image, bbox_json)``: the single predicted mask as a PIL image
        with values 0/255, and a JSON object with ``startX``/``startY``/
        ``endX``/``endY`` (all 0 when the mask is empty).

    Raises:
        gr.Error: if no image was supplied.
        json.JSONDecodeError: if a coordinate string is not valid JSON.
    """
    if image is None:
        raise gr.Error("segment: no input image was provided.")
    print("image.shape:", image.shape)
    print("coordinates_positive:", coordinates_positive)
    print("coordinates_negative:", coordinates_negative)
    print("bboxes:", bboxes)

    # Parse every prompt before set_image so malformed input fails before the
    # expensive image embedding runs.
    pos_points, pos_labels = _parse_points(coordinates_positive, 1)
    neg_points, neg_labels = _parse_points(coordinates_negative, 0)
    points = pos_points + neg_points
    input_point = np.array(points) if points else None
    input_label = np.array(pos_labels + neg_labels) if points else None
    # NOTE(review): SamPredictor.predict documents `box` as a single length-4
    # box; behavior with several boxes should be confirmed.
    input_box = _parse_boxes(bboxes)

    sam_predictor.set_image(image)
    masks, _scores, _logits = sam_predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box,
        multimask_output=False,
    )
    return _mask_outputs(masks[0])
# Gradio UI for this worker server. Each tab wires one of the handlers above to
# its input/output components. Each click handler is registered with an
# api_name, which exposes it as a named API endpoint. Component creation order
# inside each `with` context defines the rendered layout.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown("## MagicQuill Worker Server (Draw&Guess + SAM)")

    # Draw & Guess: three images in, one text prediction out (handler: guess).
    with gr.Tab("Draw & Guess"):
        with gr.Row():
            dg_input_img = gr.Image(label="Original Image")
            dg_color_img = gr.Image(label="Colored Image")
            # Loaded as single-channel ("L") so guess() receives a 2-D mask array.
            dg_edge_img = gr.Image(image_mode="L", label="Edge Mask")
        dg_output = gr.Textbox(label="Prediction Output")
        dg_btn = gr.Button("Guess")
        dg_btn.click(
            fn=guess,
            inputs=[dg_input_img, dg_color_img, dg_edge_img],
            outputs=dg_output,
            api_name="guess_prompt",
            # At most one guess event runs at a time.
            concurrency_limit=1
        )

    # SAM Segmentation: image plus JSON prompt strings in; mask image and its
    # bounding-box JSON out (handler: segment).
    with gr.Tab("SAM Segmentation"):
        with gr.Row():
            sam_input_img = gr.Image(label="Input Image", type="numpy")
            sam_pos_coords = gr.Textbox(label="Pos Coords JSON")
            sam_neg_coords = gr.Textbox(label="Neg Coords JSON")
            sam_bboxes = gr.Textbox(label="BBoxes JSON")
        with gr.Row():
            # PNG output keeps the 0/255 mask lossless.
            sam_output_img = gr.Image(label="Segmented Image", format="png")
            sam_output_bbox = gr.Textbox(label="Mask BBox JSON")
        sam_btn = gr.Button("Segment")
        sam_btn.click(
            fn=segment,
            inputs=[sam_input_img, sam_pos_coords, sam_neg_coords, sam_bboxes],
            outputs=[sam_output_img, sam_output_bbox],
            api_name="segment",
            # Up to five segment events may run concurrently.
            concurrency_limit=5
        )
if __name__ == "__main__":
    # Enable the request queue (at most 40 pending events), then start the
    # server with max_threads=5.
    queued_app = app.queue(max_size=40)
    queued_app.launch(max_threads=5)