Spaces:
Running on Zero
Adapt to ZeroGPU: load models at module scope; @spaces.GPU on guess/segment; drop ProcessPool + triton pin
046eb30 | import spaces | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import os | |
| import json | |
| from MagicQuill import folder_paths | |
| from MagicQuill.llava_new import LLaVAModel | |
| from huggingface_hub import snapshot_download | |
| from segment_anything import sam_model_registry, SamPredictor | |
# --- Model weights -----------------------------------------------------------
# HF_TOKEN may be unset (os.environ.get -> None); it is passed only to the
# MagicQuillV2 download, the first bundle is fetched without a token argument.
hf_token = os.environ.get("HF_TOKEN")
snapshot_download(
    repo_id="LiuZichen/MagicQuill-models",
    repo_type="model",
    local_dir="models",
)
snapshot_download(
    repo_id="LiuZichen/MagicQuillV2-models",
    repo_type="model",
    local_dir="models_v2",
    token=hf_token,
)

# --- Models: loaded once at import time and shared by guess()/segment() -----
print("Initializing LLaVAModel...")
llavaModel = LLaVAModel()
print("LLaVAModel initialized.")

print("Initializing SAM...")
# ViT-B SAM checkpoint shipped inside the MagicQuillV2 bundle downloaded above.
sam = sam_model_registry['vit_b'](
    checkpoint='models_v2/sam/sam_vit_b_01ec64.pth',
)
sam.to(device='cuda')
sam_predictor = SamPredictor(sam)
print("SAM initialized.")
def numpy_to_tensor(numpy_array):
    """Turn an image array into a float32 tensor with a leading batch axis.

    Every element is divided by 255, so pixel values in 0-255 (presumably the
    uint8 arrays produced by gr.Image -- confirm for other sources) map to
    [0, 1]. The result has shape ``(1, *numpy_array.shape)``.
    """
    scaled = torch.from_numpy(numpy_array).float() / 255.
    return scaled.unsqueeze(0)
@spaces.GPU
def guess(original_image, add_color_image, add_edge_mask):
    """Ask the LLaVA model what the user drew and return its guess as text.

    Decorated with ``@spaces.GPU`` so the call runs with a GPU attached on
    ZeroGPU hardware (the module imports ``spaces`` for exactly this).

    Args:
        original_image: the unedited image as a numpy array (from gr.Image).
        add_color_image: the image with the user's color strokes, numpy array.
        add_edge_mask: the edge mask, numpy array (the UI loads it with
            image_mode="L").

    Returns:
        The non-empty answers among ``ans1``/``ans2`` returned by
        ``llavaModel.process``, joined with ", "; "" if neither is set.

    Raises:
        gr.Error: if any of the three images is missing (previously this
            surfaced as an opaque TypeError from torch.from_numpy(None)).
    """
    if original_image is None or add_color_image is None or add_edge_mask is None:
        raise gr.Error("Original image, colored image and edge mask are all required.")

    # The free-form description is not used; only the two answers are returned.
    _description, ans1, ans2 = llavaModel.process(
        numpy_to_tensor(original_image),
        numpy_to_tensor(add_color_image),
        numpy_to_tensor(add_edge_mask),
    )
    # Truthiness already excludes None and "", so no separate != "" check.
    return ", ".join(ans for ans in (ans1, ans2) if ans)
def get_mask_bbox(mask_np):
    """Return the tight bounding box of the non-zero pixels in a mask.

    A 3-D input is reduced to its first slice along axis 0 (this assumes a
    (1, H, W)-style layout -- NOTE(review): an (H, W, C) mask would be sliced
    by row instead; confirm callers never pass channels-last masks).

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as Python ints with inclusive
        bounds, or ``None`` when the mask contains no non-zero pixel.
    """
    plane = mask_np[0] if mask_np.ndim == 3 else mask_np
    row_hits = np.where(plane.any(axis=1))[0]
    col_hits = np.where(plane.any(axis=0))[0]
    if row_hits.size == 0 or col_hits.size == 0:
        return None
    return int(col_hits[0]), int(row_hits[0]), int(col_hits[-1]), int(row_hits[-1])
import threading

# SamPredictor is stateful: set_image() stores the image embedding that the
# following predict() call reads. The "segment" endpoint is registered with
# concurrency_limit=5, so without serialization one request could predict
# against another request's image. Hold this lock across the pair.
_sam_lock = threading.Lock()


def _load_json_if_str(value):
    """Return ``json.loads(value)`` for strings, otherwise ``value`` unchanged."""
    return json.loads(value) if isinstance(value, str) else value


def _collect_points(coordinates_positive, coordinates_negative):
    """Build SAM point prompts from positive/negative click lists.

    Each click must be a mapping with 'x' and 'y' keys. Positive clicks get
    label 1, negative clicks label 0. Returns ``(None, None)`` when there are
    no clicks, otherwise ``(points (N, 2) array, labels (N,) array)``.
    """
    points, labels = [], []
    for raw, label in ((coordinates_positive, 1), (coordinates_negative, 0)):
        if not raw:
            continue
        for p in _load_json_if_str(raw):
            points.append([p['x'], p['y']])
            labels.append(label)
    if not points:
        return None, None
    return np.array(points), np.array(labels)


def _collect_boxes(bboxes):
    """Parse the box prompt(s) into an array, or None if there are none.

    Parsing is best-effort: a string that is not valid JSON, or a value that
    is not a list, yields None (segmentation then runs on points alone).
    """
    if not bboxes:
        return None
    if isinstance(bboxes, str):
        try:
            bboxes = json.loads(bboxes)
        except ValueError as err:  # json.JSONDecodeError subclasses ValueError
            print("Ignoring unparsable bboxes JSON:", err)
            return None
    if not isinstance(bboxes, list):
        return None
    box_list = [list(box) for box in bboxes]
    # NOTE(review): SamPredictor.predict documents `box` as a single length-4
    # XYXY box; more than one box here likely fails inside SAM's prompt
    # encoder -- confirm callers only ever send one.
    return np.array(box_list) if box_list else None


@spaces.GPU
def segment(image, coordinates_positive, coordinates_negative, bboxes):
    """Segment *image* with SAM using optional click and box prompts.

    Decorated with ``@spaces.GPU`` so the call runs with a GPU attached on
    ZeroGPU hardware.

    Args:
        image: input image as a numpy array (gr.Image(type="numpy")).
        coordinates_positive: JSON string (or decoded list) of
            ``{'x': ..., 'y': ...}`` foreground clicks; may be empty.
        coordinates_negative: same format, background clicks; may be empty.
        bboxes: JSON string (or list) of boxes, each a 4-element sequence
            forwarded as SAM's ``box`` prompt; may be empty.

    Returns:
        ``(mask_image, bbox_json)``: a PIL image of the binary mask (uint8,
        0/255) and a JSON string with keys startX/startY/endX/endY giving the
        mask's inclusive bounding box (all zeros for an empty mask).

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("An input image is required for segmentation.")

    print("image.shape:", image.shape)
    print("coordinates_positive:", coordinates_positive)
    print("coordinates_negative:", coordinates_negative)
    print("bboxes:", bboxes)

    # Parse prompts first so malformed click JSON fails before the costly
    # image embedding is computed.
    input_point, input_label = _collect_points(coordinates_positive, coordinates_negative)
    input_box = _collect_boxes(bboxes)

    with _sam_lock:
        sam_predictor.set_image(image)
        masks, _scores, _logits = sam_predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            box=input_box,
            multimask_output=False,
        )

    # multimask_output=False -> one mask; binarize (works for bool or numeric).
    mask_np = (masks[0] > 0).astype(np.uint8) * 255

    mask_bbox = get_mask_bbox(mask_np)
    if mask_bbox is None:
        x_min = y_min = x_max = y_max = 0
    else:
        x_min, y_min, x_max, y_max = mask_bbox
    seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
    return Image.fromarray(mask_np), json.dumps(seg_bbox)
# Gradio UI for the worker. Two endpoints are exposed by api_name:
# "guess_prompt" (LLaVA draw-and-guess) and "segment" (SAM segmentation).
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown("## MagicQuill Worker Server (Draw&Guess + SAM)")
    with gr.Tab("Draw & Guess"):
        with gr.Row():
            dg_input_img = gr.Image(label="Original Image")
            dg_color_img = gr.Image(label="Colored Image")
            # Edge mask is loaded as single-channel ("L") grayscale.
            dg_edge_img = gr.Image(image_mode="L", label="Edge Mask")
        dg_output = gr.Textbox(label="Prediction Output")
        dg_btn = gr.Button("Guess")
        # concurrency_limit=1: at most one guess request runs at a time.
        dg_btn.click(
            fn=guess,
            inputs=[dg_input_img, dg_color_img, dg_edge_img],
            outputs=dg_output,
            api_name="guess_prompt",
            concurrency_limit=1
        )
    with gr.Tab("SAM Segmentation"):
        with gr.Row():
            sam_input_img = gr.Image(label="Input Image", type="numpy")
            # Prompts arrive as JSON text; segment() decodes them.
            sam_pos_coords = gr.Textbox(label="Pos Coords JSON")
            sam_neg_coords = gr.Textbox(label="Neg Coords JSON")
            sam_bboxes = gr.Textbox(label="BBoxes JSON")
        with gr.Row():
            sam_output_img = gr.Image(label="Segmented Image", format="png")
            sam_output_bbox = gr.Textbox(label="Mask BBox JSON")
        sam_btn = gr.Button("Segment")
        # concurrency_limit=5: up to five segment requests may run at once;
        # they all share the single module-level sam_predictor.
        sam_btn.click(
            fn=segment,
            inputs=[sam_input_img, sam_pos_coords, sam_neg_coords, sam_bboxes],
            outputs=[sam_output_img, sam_output_bbox],
            api_name="segment",
            concurrency_limit=5
        )
if __name__ == "__main__":
    # Enable the request queue (at most 40 waiting requests), then start the
    # server with a cap of 5 worker threads. Blocks.queue() returns the same
    # Blocks object, so configuring it in place and launching is equivalent
    # to chaining the two calls.
    app.queue(max_size=40)
    app.launch(max_threads=5)