File size: 2,730 Bytes
2fbdc5f
 
68be95b
2fbdc5f
 
 
 
 
68be95b
 
 
2fbdc5f
68be95b
 
2fbdc5f
68be95b
 
2fbdc5f
 
68be95b
 
2fbdc5f
 
 
 
 
 
 
 
 
 
 
 
 
 
68be95b
2fbdc5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68be95b
2fbdc5f
 
 
68be95b
2fbdc5f
 
 
 
 
 
 
 
68be95b
2fbdc5f
 
 
 
68be95b
2fbdc5f
 
 
 
 
 
68be95b
 
 
 
2fbdc5f
 
 
 
 
68be95b
2fbdc5f
 
68be95b
2fbdc5f
 
68be95b
 
2fbdc5f
68be95b
2fbdc5f
 
68be95b
2fbdc5f
68be95b
2fbdc5f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import re
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import gradio as gr

# Hugging Face model id for the UI-TARS 1.5 7B GUI-agent checkpoint.
MODEL_ID = "ByteDance-Seed/UI-TARS-1.5-7B"

# ----------------------------
# Load model (CPU optimized)
# ----------------------------
# Processor bundles the tokenizer and image preprocessor; downloads on first run.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# NOTE(review): a 7B model in float32 needs ~28 GB RAM — confirm the host has it.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,        # CPU safe
    low_cpu_mem_usage=True
)

# Inference only: switch off dropout / training-mode layers.
model.eval()

# ----------------------------
# Coordinate Extraction
# ----------------------------
# Matches a number like "12", "0.5", or ".5" — never a bare "." or "..",
# so float() on a match can never raise.
_NUM = r"\d*\.?\d+"


def extract_coordinates(text, image_size):
    """Parse the first point "(x, y)" or box "[x1, y1, x2, y2]" in *text*.

    Values that are all <= 1 are treated as normalized fractions of the
    image and scaled by *image_size* (width, height); otherwise they are
    taken as pixel coordinates.

    Returns:
        (x, y) for a point, (x1, y1, x2, y2) for a box, or None when
        nothing parseable is found.
    """
    width, height = image_size

    # Point form takes priority: "(x, y)".
    match = re.search(rf"\(({_NUM}),\s*({_NUM})\)", text)
    if match:
        x, y = float(match.group(1)), float(match.group(2))

        if x <= 1 and y <= 1:
            # Normalized -> pixel space. NOTE(review): a literal pixel
            # point (1, 1) is indistinguishable from normalized here.
            x, y = int(x * width), int(y * height)
        else:
            x, y = int(x), int(y)

        return (x, y)

    # Box form: "[x1, y1, x2, y2]". Extract number tokens with findall
    # instead of split(",") so a trailing comma or space-separated list
    # does not crash float() with ValueError.
    match_box = re.search(r"\[([\d.,\s]+)\]", text)
    if match_box:
        nums = [float(n) for n in re.findall(_NUM, match_box.group(1))]
        if len(nums) == 4:
            x1, y1, x2, y2 = nums

            if max(nums) <= 1:
                x1, x2 = int(x1 * width), int(x2 * width)
                y1, y2 = int(y1 * height), int(y2 * height)
            else:
                x1, y1, x2, y2 = map(int, nums)

            return (x1, y1, x2, y2)

    return None


# ----------------------------
# Prediction
# ----------------------------
def predict(image, prompt):
    """Run the model on *image* + *prompt*.

    Returns a pair: the raw decoded model text, and a human-readable
    coordinate summary (or a "not detected" message).
    """
    if image is None:
        return "Upload image", "No coordinates"

    pil_img = Image.fromarray(image).convert("RGB")

    model_inputs = processor(
        images=pil_img,
        text=prompt,
        return_tensors="pt"
    )

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        generated = model.generate(**model_inputs, max_new_tokens=150)

    decoded = processor.batch_decode(generated, skip_special_tokens=True)[0]

    # pil_img.size is (width, height), matching extract_coordinates' contract.
    coords = extract_coordinates(decoded, pil_img.size)

    if coords:
        coord_text = f"{coords} (origin: top-left, x→right, y↓)"
    else:
        coord_text = "No coordinates detected"

    return decoded, coord_text


# ----------------------------
# UI
# ----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# UI-TARS CPU Demo (Slow ⚠️)")

    # Inputs side by side; outputs stacked below the run button.
    with gr.Row():
        img_in = gr.Image(type="numpy", label="Image")
        prompt_in = gr.Textbox(label="Prompt")

    run_btn = gr.Button("Run")

    model_out = gr.Textbox(label="Model Output")
    coords_out = gr.Textbox(label="Coordinates")

    run_btn.click(
        fn=predict,
        inputs=[img_in, prompt_in],
        outputs=[model_out, coords_out]
    )

demo.launch()