Spaces:
Running on Zero
Running on Zero
File size: 5,543 Bytes
046eb30 0e84795 046eb30 191dbfa 0e84795 191dbfa 0e84795 191dbfa 0e84795 046eb30 0e84795 191dbfa 046eb30 191dbfa 046eb30 191dbfa 046eb30 191dbfa 046eb30 191dbfa 046eb30 191dbfa 046eb30 191dbfa 046eb30 191dbfa 046eb30 191dbfa 046eb30 191dbfa 046eb30 191dbfa 046eb30 191dbfa 046eb30 191dbfa 046eb30 191dbfa 644a908 191dbfa 046eb30 191dbfa 046eb30 191dbfa 046eb30 191dbfa 046eb30 191dbfa 644a908 0f883aa 191dbfa 046eb30 191dbfa 0f883aa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 | import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
import os
import json
from MagicQuill import folder_paths
from MagicQuill.llava_new import LLaVAModel
from huggingface_hub import snapshot_download
from segment_anything import sam_model_registry, SamPredictor
# Optional Hugging Face access token, read from the HF_TOKEN environment
# variable (None when the variable is unset).
hf_token = os.environ.get("HF_TOKEN")
# Fetch both model repositories at import time into local directories.
# Only the V2 repository download is given the token.
snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models_v2", token=hf_token)
# Build the LLaVA wrapper once at module load; guess() uses this global.
print("Initializing LLaVAModel...")
llavaModel = LLaVAModel()
print("LLaVAModel initialized.")
# Build the 'vit_b' SAM variant from the checkpoint inside the models_v2
# download, move it to CUDA, and wrap it in a predictor used by segment().
print("Initializing SAM...")
sam = sam_model_registry['vit_b'](checkpoint='models_v2/sam/sam_vit_b_01ec64.pth')
sam.to(device='cuda')
sam_predictor = SamPredictor(sam)
print("SAM initialized.")
def numpy_to_tensor(numpy_array):
    """Convert a NumPy array into a float tensor with a new leading axis.

    The values are cast to float and divided by 255, so 8-bit data in
    ``[0, 255]`` ends up in ``[0, 1]``.
    """
    as_float = torch.from_numpy(numpy_array).float()
    return as_float.unsqueeze(0) / 255.
@spaces.GPU
def guess(original_image, add_color_image, add_edge_mask):
    """Ask the LLaVA model what was drawn and return its answers as one string.

    Each of the three image arrays is converted with ``numpy_to_tensor`` and
    passed to ``llavaModel.process``, which yields a description (not used
    here) followed by two answers.

    Returns:
        The non-empty answers joined with ``", "``; an empty string when both
        answers are empty or ``None``.
    """
    tensors = [
        numpy_to_tensor(array)
        for array in (original_image, add_color_image, add_edge_mask)
    ]
    _description, ans1, ans2 = llavaModel.process(*tensors)
    # Falsy answers (None or "") are dropped before joining.
    return ", ".join(answer for answer in (ans1, ans2) if answer)
def get_mask_bbox(mask_np):
    """Return the tight bounding box of the non-zero pixels of a mask.

    A 3-D input is reduced to its first slice along axis 0 before the
    search.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as plain ints with inclusive
        bounds, or ``None`` when the mask has no non-zero pixel.
    """
    plane = mask_np[0] if mask_np.ndim == 3 else mask_np
    # Indices of the rows / columns that contain at least one set pixel.
    row_hits = np.nonzero(np.any(plane, axis=1))[0]
    col_hits = np.nonzero(np.any(plane, axis=0))[0]
    if row_hits.size == 0 or col_hits.size == 0:
        return None
    top, bottom = row_hits[0], row_hits[-1]
    left, right = col_hits[0], col_hits[-1]
    return int(left), int(top), int(right), int(bottom)
def _parse_points(raw_coords):
    """Return ``[[x, y], ...]`` from a JSON string or a list of point dicts.

    A falsy value (``None``, ``""``, ``[]``) yields an empty list. Each point
    must expose the keys ``'x'`` and ``'y'``.
    """
    if not raw_coords:
        return []
    coords = json.loads(raw_coords) if isinstance(raw_coords, str) else raw_coords
    return [[p['x'], p['y']] for p in coords]


def _parse_boxes(raw_boxes):
    """Return the boxes as a NumPy array, or ``None`` when there are none.

    Accepts a JSON string or a list of iterables. A string that is not valid
    JSON is reported on stdout and then ignored, so a malformed box prompt
    degrades to "no box" instead of failing the request.
    """
    if not raw_boxes:
        return None
    if isinstance(raw_boxes, str):
        try:
            raw_boxes = json.loads(raw_boxes)
        except ValueError as exc:  # json.JSONDecodeError is a ValueError
            print("Ignoring bboxes that are not valid JSON:", exc)
            return None
    if not isinstance(raw_boxes, list):
        return None
    box_list = [list(box) for box in raw_boxes]
    return np.array(box_list) if box_list else None


@spaces.GPU
def segment(image, coordinates_positive, coordinates_negative, bboxes):
    """Run SAM on ``image`` with point and/or box prompts.

    Args:
        image: Image array handed to ``sam_predictor.set_image``.
        coordinates_positive: JSON string or list of ``{'x': ..., 'y': ...}``
            points labelled 1 for the predictor.
        coordinates_negative: Same format; points labelled 0.
        bboxes: JSON string or list of boxes. Presumably each box is
            ``[x0, y0, x1, y1]`` as SAM expects -- TODO confirm with caller.

    Returns:
        A tuple ``(mask_image, bbox_json)``: a PIL image of the first
        predicted mask as uint8 values 0/255, and a JSON string with the keys
        ``startX``, ``startY``, ``endX``, ``endY`` describing the mask's
        bounding box (all zeros when the mask is empty).
    """
    print("image.shape:", image.shape)
    print("coordinates_positive:", coordinates_positive)
    print("coordinates_negative:", coordinates_negative)
    print("bboxes:", bboxes)

    # NOTE(review): sam_predictor is a shared module-level object; set_image
    # followed by predict is not obviously safe under concurrent requests.
    sam_predictor.set_image(image)

    positive = _parse_points(coordinates_positive)
    negative = _parse_points(coordinates_negative)
    points = positive + negative
    labels = [1] * len(positive) + [0] * len(negative)

    input_box = _parse_boxes(bboxes)

    # The predictor takes None (not empty arrays) when no points are given.
    input_point = np.array(points) if points else None
    input_label = np.array(labels) if points else None

    masks, _scores, _logits = sam_predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box,
        multimask_output=False,
    )

    # Binarise to 0/255; `> 0` is a no-op on boolean masks and thresholds
    # any other dtype, so a single expression covers both cases.
    mask_np = (masks[0] > 0).astype(np.uint8) * 255
    res_pil = Image.fromarray(mask_np)

    mask_bbox = get_mask_bbox(mask_np)
    if mask_bbox:
        x_min, y_min, x_max, y_max = mask_bbox
        seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
    else:
        seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}

    return res_pil, json.dumps(seg_bbox)
# Gradio UI exposing the two worker functions as tabs and as named API
# endpoints ("guess_prompt" and "segment").
# NOTE(review): the nesting below was reconstructed from a listing that had
# lost its indentation -- confirm the Row/Tab layout against the real file.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown("## MagicQuill Worker Server (Draw&Guess + SAM)")
    # Tab 1: three image inputs feeding guess(), one text output.
    with gr.Tab("Draw & Guess"):
        with gr.Row():
            dg_input_img = gr.Image(label="Original Image")
            dg_color_img = gr.Image(label="Colored Image")
            dg_edge_img = gr.Image(image_mode="L", label="Edge Mask")
        dg_output = gr.Textbox(label="Prediction Output")
        dg_btn = gr.Button("Guess")
        # Only one guess() call is allowed to run at a time.
        dg_btn.click(
            fn=guess,
            inputs=[dg_input_img, dg_color_img, dg_edge_img],
            outputs=dg_output,
            api_name="guess_prompt",
            concurrency_limit=1
        )
    # Tab 2: image plus JSON-encoded prompts feeding segment(), which returns
    # the mask image and a JSON bounding box.
    with gr.Tab("SAM Segmentation"):
        with gr.Row():
            sam_input_img = gr.Image(label="Input Image", type="numpy")
            sam_pos_coords = gr.Textbox(label="Pos Coords JSON")
            sam_neg_coords = gr.Textbox(label="Neg Coords JSON")
            sam_bboxes = gr.Textbox(label="BBoxes JSON")
        with gr.Row():
            sam_output_img = gr.Image(label="Segmented Image", format="png")
            sam_output_bbox = gr.Textbox(label="Mask BBox JSON")
        sam_btn = gr.Button("Segment")
        # Up to five segment() calls may run concurrently.
        sam_btn.click(
            fn=segment,
            inputs=[sam_input_img, sam_pos_coords, sam_neg_coords, sam_bboxes],
            outputs=[sam_output_img, sam_output_bbox],
            api_name="segment",
            concurrency_limit=5
        )
# Launch only when run as a script: queue capped at 40 pending requests,
# server limited to 5 threads.
if __name__ == "__main__":
    app.queue(max_size=40).launch(max_threads=5)
|