import torch
import re
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import gradio as gr
MODEL_ID = "ByteDance-Seed/UI-TARS-1.5-7B"
# ----------------------------
# Load model (CPU optimized)
# ----------------------------
processor = AutoProcessor.from_pretrained(MODEL_ID)
# UI-TARS-1.5 is built on Qwen2.5-VL, a multimodal architecture, so it is
# loaded via the image-text-to-text auto class (AutoModelForCausalLM cannot
# resolve this checkpoint).
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,  # full fp32 is the safe dtype on CPU
    low_cpu_mem_usage=True,
)
model.eval()
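# Optional CPU tuning (assumption: a multi-core host; uncomment to use all cores):
# import os; torch.set_num_threads(os.cpu_count())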
# ----------------------------
# Coordinate Extraction
# ----------------------------
def extract_coordinates(text, image_size):
    """Parse a point "(x, y)" or a box "[x1, y1, x2, y2]" from model text.

    Values <= 1 are treated as normalized coordinates and scaled to pixels.
    """
    width, height = image_size
    # Point: first "(x, y)" pair in the text.
    match = re.search(r"\(([\d\.]+),\s*([\d\.]+)\)", text)
    if match:
        x, y = float(match.group(1)), float(match.group(2))
        if x <= 1 and y <= 1:
            x, y = int(x * width), int(y * height)
        else:
            x, y = int(x), int(y)
        return (x, y)
    # Bounding box: "[x1, y1, x2, y2]".
    match_box = re.search(r"\[([\d\.,\s]+)\]", text)
    if match_box:
        nums = list(map(float, match_box.group(1).split(",")))
        if len(nums) == 4:
            x1, y1, x2, y2 = nums
            if max(nums) <= 1:
                x1, x2 = int(x1 * width), int(x2 * width)
                y1, y2 = int(y1 * height), int(y2 * height)
            else:
                x1, y1, x2, y2 = map(int, nums)
            return (x1, y1, x2, y2)
    return None
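# Example parses (hypothetical strings, 1280x720 image):
#   extract_coordinates("click(0.42, 0.17)", (1280, 720))    -> (537, 122)
#   extract_coordinates("[100, 200, 300, 400]", (1280, 720)) -> (100, 200, 300, 400)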
# ----------------------------
# Prediction
# ----------------------------
def predict(image, prompt):
    if image is None:
        return "Upload an image first.", "No coordinates"
    image_pil = Image.fromarray(image).convert("RGB")
    width, height = image_pil.size
    # UI-TARS-1.5 ships a Qwen2.5-VL-style processor: the image placeholder
    # token must come from the chat template, not from raw prompt text.
    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ],
    }]
    chat_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(images=image_pil, text=chat_prompt, return_tensors="pt")
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=150)
    # Decode only the newly generated tokens, not the echoed prompt
    # (otherwise coordinate parsing could pick numbers out of the prompt).
    generated = output[:, inputs["input_ids"].shape[1]:]
    result = processor.batch_decode(generated, skip_special_tokens=True)[0]
    coords = extract_coordinates(result, (width, height))
    coord_text = (
        f"{coords} (origin: top-left, x→right, y↓)"
        if coords else "No coordinates detected"
    )
    return result, coord_text
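# Format note (assumption, not from the model card): UI-TARS-style agents
# often reply with an action string such as "click(start_box='(0.42,0.17)')";
# the point regex in extract_coordinates captures the first "(x, y)" pair
# from such replies.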
# ----------------------------
# UI
# ----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# UI-TARS CPU Demo (Slow ⚠️)")
    with gr.Row():
        image_input = gr.Image(type="numpy", label="Image")
        text_input = gr.Textbox(label="Prompt")
    btn = gr.Button("Run")
    output_text = gr.Textbox(label="Model Output")
    coord_output = gr.Textbox(label="Coordinates")
    btn.click(
        fn=predict,
        inputs=[image_input, text_input],
        outputs=[output_text, coord_output],
    )

demo.launch()
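# To serve beyond localhost (e.g. in a container or a Hugging Face Space),
# standard Gradio options apply, for example:
#   demo.launch(server_name="0.0.0.0", server_port=7860)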