"""Gradio worker exposing a "Draw & Guess" endpoint and a SAM segmentation endpoint.

At import time this module downloads the model repositories, instantiates
``LLaVAModel`` and loads a SAM ``vit_b`` checkpoint onto CUDA. The two
handlers (``guess`` and ``segment``) are wired to a ``gr.Blocks`` UI under the
API names ``guess_prompt`` and ``segment``.
"""
import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
import os
import json
from MagicQuill import folder_paths
from MagicQuill.llava_new import LLaVAModel
from huggingface_hub import snapshot_download
from segment_anything import sam_model_registry, SamPredictor

hf_token = os.environ.get("HF_TOKEN")

# Only the second repository is fetched with the token.
snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models_v2", token=hf_token)

print("Initializing LLaVAModel...")
llavaModel = LLaVAModel()
print("LLaVAModel initialized.")

print("Initializing SAM...")
sam = sam_model_registry['vit_b'](checkpoint='models_v2/sam/sam_vit_b_01ec64.pth')
sam.to(device='cuda')
sam_predictor = SamPredictor(sam)
print("SAM initialized.")


def numpy_to_tensor(numpy_array: np.ndarray) -> torch.Tensor:
    """Convert an array to a float tensor with a leading axis, scaled by 1/255."""
    tensor = torch.from_numpy(numpy_array).float().unsqueeze(0) / 255.
    return tensor


@spaces.GPU
def guess(original_image, add_color_image, add_edge_mask):
    """Ask the LLaVA model what the user drew and return its answers.

    Args:
        original_image: The unedited image as a numpy array.
        add_color_image: The image with colour strokes as a numpy array.
        add_edge_mask: The edge mask as a numpy array.

    Returns:
        The non-empty answers returned by ``llavaModel.process`` joined with
        ``", "`` (an empty string when there are none).

    Raises:
        gr.Error: If any of the three images is missing.
    """
    if original_image is None or add_color_image is None or add_edge_mask is None:
        # Without this guard torch.from_numpy(None) fails with an opaque TypeError.
        raise gr.Error("Original image, colored image and edge mask are all required.")

    original_image_tensor = numpy_to_tensor(original_image)
    add_color_image_tensor = numpy_to_tensor(add_color_image)
    add_edge_mask_tensor = numpy_to_tensor(add_edge_mask)

    # The first returned value (a description) is not used by this endpoint.
    _description, ans1, ans2 = llavaModel.process(
        original_image_tensor, add_color_image_tensor, add_edge_mask_tensor
    )
    return ", ".join(ans for ans in (ans1, ans2) if ans)


def get_mask_bbox(mask_np):
    """Return the bounding box of the non-zero region of a mask.

    Args:
        mask_np: A 2-D mask, or a 3-D array whose first slice along axis 0 is
            the mask.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as Python ints, with the maxima being
        the index of the last non-zero column/row (inclusive), or ``None`` when
        the mask has no non-zero element.
    """
    if mask_np.ndim == 3:
        mask_np = mask_np[0]

    rows = np.any(mask_np, axis=1)
    cols = np.any(mask_np, axis=0)
    if not np.any(rows) or not np.any(cols):
        return None

    y_min, y_max = np.where(rows)[0][[0, -1]]
    x_min, x_max = np.where(cols)[0][[0, -1]]
    return int(x_min), int(y_min), int(x_max), int(y_max)


def _parse_points(raw):
    """Turn a JSON string (or already-parsed sequence) of ``{'x', 'y'}`` items into ``[[x, y], ...]``."""
    if not raw:
        return []
    coords = json.loads(raw) if isinstance(raw, str) else raw
    return [[p['x'], p['y']] for p in coords]


def _parse_boxes(bboxes):
    """Turn a JSON string (or list) of boxes into an array, or ``None`` when unusable."""
    if not bboxes:
        return None
    if isinstance(bboxes, str):
        try:
            bboxes = json.loads(bboxes)
        except json.JSONDecodeError as exc:
            # Best effort: a malformed box prompt is dropped rather than failing the request.
            print("Ignoring malformed bboxes JSON:", exc)
            return None
    if not isinstance(bboxes, list):
        return None
    # NOTE(review): each entry is expected to be iterable (presumably 4 numbers) - confirm with callers.
    box_list = [list(box) for box in bboxes]
    return np.array(box_list) if box_list else None


@spaces.GPU
def segment(image, coordinates_positive, coordinates_negative, bboxes):
    """Segment ``image`` with SAM from point and/or box prompts.

    Args:
        image: Input image as a numpy array.
        coordinates_positive: JSON string or parsed sequence of ``{'x', 'y'}``
            items used as foreground points (label 1). May be empty.
        coordinates_negative: Same format, used as background points (label 0).
        bboxes: JSON string or list of boxes. Malformed JSON is ignored.

    Returns:
        A tuple ``(mask_image, bbox_json)``: a PIL image built from a uint8
        mask with values 0/255, and a JSON string with keys ``startX``,
        ``startY``, ``endX``, ``endY`` (all zero when the mask is empty).

    Raises:
        gr.Error: If no image was provided.
        json.JSONDecodeError: If a coordinates string is not valid JSON.
    """
    if image is None:
        # Without this guard the first print fails with an opaque AttributeError.
        raise gr.Error("An input image is required for segmentation.")

    print("image.shape:", image.shape)
    print("coordinates_positive:", coordinates_positive)
    print("coordinates_negative:", coordinates_negative)
    print("bboxes:", bboxes)

    sam_predictor.set_image(image)

    positive_points = _parse_points(coordinates_positive)
    negative_points = _parse_points(coordinates_negative)
    points = positive_points + negative_points
    labels = [1] * len(positive_points) + [0] * len(negative_points)

    input_box = _parse_boxes(bboxes)

    if points:
        input_point = np.array(points)
        input_label = np.array(labels)
    else:
        input_point = None
        input_label = None

    masks, _scores, _logits = sam_predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box,
        multimask_output=False,
    )

    # `> 0` is the identity on boolean masks and a threshold on numeric ones,
    # so one expression covers both dtypes.
    mask_np = (masks[0] > 0).astype(np.uint8) * 255
    res_pil = Image.fromarray(mask_np)

    mask_bbox = get_mask_bbox(mask_np)
    if mask_bbox is not None:
        x_min, y_min, x_max, y_max = mask_bbox
        seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
    else:
        seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}

    return res_pil, json.dumps(seg_bbox)


# NOTE(review): the widget nesting below was reconstructed from a
# whitespace-collapsed source; confirm the row/tab layout against the
# intended UI.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown("## MagicQuill Worker Server (Draw&Guess + SAM)")

    with gr.Tab("Draw & Guess"):
        with gr.Row():
            dg_input_img = gr.Image(label="Original Image")
            dg_color_img = gr.Image(label="Colored Image")
            dg_edge_img = gr.Image(image_mode="L", label="Edge Mask")
        dg_output = gr.Textbox(label="Prediction Output")
        dg_btn = gr.Button("Guess")
        dg_btn.click(
            fn=guess,
            inputs=[dg_input_img, dg_color_img, dg_edge_img],
            outputs=dg_output,
            api_name="guess_prompt",
            concurrency_limit=1
        )

    with gr.Tab("SAM Segmentation"):
        with gr.Row():
            sam_input_img = gr.Image(label="Input Image", type="numpy")
            sam_pos_coords = gr.Textbox(label="Pos Coords JSON")
            sam_neg_coords = gr.Textbox(label="Neg Coords JSON")
            sam_bboxes = gr.Textbox(label="BBoxes JSON")
        with gr.Row():
            sam_output_img = gr.Image(label="Segmented Image", format="png")
            sam_output_bbox = gr.Textbox(label="Mask BBox JSON")
        sam_btn = gr.Button("Segment")
        sam_btn.click(
            fn=segment,
            inputs=[sam_input_img, sam_pos_coords, sam_neg_coords, sam_bboxes],
            outputs=[sam_output_img, sam_output_bbox],
            api_name="segment",
            concurrency_limit=5
        )

if __name__ == "__main__":
    app.queue(max_size=40).launch(max_threads=5)