songtianhui committed 5cdcb65 (1 parent: 56399ad): add readme

Files changed (1): README.md added (+156 lines)

---
license: mit
library_name: transformers
---
# Towards Pixel-level VLM Perception via Simple Points Prediction

<div align="center">
<a href="">
<b>📄 Tech Report</b>
</a> &nbsp;|&nbsp;
<a href="https://github.com/songtianhui/SimpleSeg">
<b>📄 GitHub</b>
</a>
</div>

## Introduction

> [!NOTE]
> This is the Kimi-VL version of SimpleSeg, built on a 16B-A3B architecture (16B total parameters, ~3B activated).

We present **SimpleSeg**, **a strikingly simple yet highly effective approach to endow Multimodal Large Language Models (MLLMs) with native pixel-level perception**.
Our method reframes segmentation as a simple sequence-generation problem: the model directly predicts a **sequence of points** (textual coordinates) delineating object boundaries, entirely within its language space.
To achieve high fidelity, we introduce a two-stage SFT→RL training pipeline, in which reinforcement learning with an IoU-based reward refines the point sequences to accurately match ground-truth contours.
We find that **the standard MLLM architecture possesses a strong, inherent capacity for low-level perception** that can be unlocked without any specialized architecture.
On segmentation benchmarks, SimpleSeg achieves performance that is comparable to, and often surpasses, methods relying on complex, task-specific designs.
This work demonstrates that precise spatial understanding can emerge from simple point prediction, challenging the prevailing need for auxiliary components and paving the way for more unified and capable VLMs.

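The sketch below is a minimal illustration of such an IoU-based reward, assuming a single predicted and a single ground-truth polygon in normalized coordinates, rasterized with `pycocotools` exactly as in the decoding snippet at the end of this README; the function names are ours and the reward used in the actual training pipeline may differ in details (e.g., multi-polygon objects or additional format terms).

```python
import numpy as np
import pycocotools.mask as mask_utils


def polygon_to_mask(polygon, height, width):
    """Rasterize one normalized polygon [[x, y], ...] into a binary mask."""
    pts = np.asarray(polygon, dtype=np.float64) * np.array([width, height])
    rle = mask_utils.frPyObjects(pts.reshape(1, -1).tolist(), height, width)
    return mask_utils.decode(rle)[..., 0].astype(bool)


def iou_reward(pred_polygon, gt_polygon, height, width):
    """Scalar reward: mask IoU between predicted and ground-truth contours."""
    pred = polygon_to_mask(pred_polygon, height, width)
    gt = polygon_to_mask(gt_polygon, height, width)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 0.0
    return float(np.logical_and(pred, gt).sum() / union)
```

In an RL loop, this scalar would score each sampled point sequence against the annotated contour.
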
## Method

![](method.png)

In this work, we explore the limits of MLLM pixel-level perception by predicting the next point in a contour with the simplest approach possible.
Without introducing any complex architectures or special patterns, we show how even minimalistic point prediction can achieve effective segmentation at the pixel level.

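Concretely, a segmentation request and its answer are ordinary text. The pair below is a hypothetical illustration of the format implied by the inference and decoding snippets later in this README (a JSON-style list of normalized `[x, y]` points); the coordinate values are made up.

```python
# Hypothetical request/response pair; coordinates are illustrative only.
prompt = "Output the polygon coordinates of the dog in the image."
response = "[[0.31, 0.22], [0.35, 0.20], [0.42, 0.24], [0.44, 0.31], [0.38, 0.36], [0.30, 0.27]]"
```

Because the answer stays within the ordinary text vocabulary, no mask decoder or special segmentation tokens are needed.
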
## Key Benefits

- **Simplicity**: SimpleSeg requires no specialized modules and adheres to the standard MLLM architecture, so it can be seamlessly and efficiently integrated as a new, core pre-training task for foundation models, similar to visual grounding.
- **Task Generality**: By framing segmentation as a text-generation problem, our approach is inherently flexible. The model can be easily adapted to a wide range of vision-language tasks that require precise spatial localization.
- **Interpretable Output**: The model generates explicit, human-readable coordinate sequences instead of dense pixel masks. This transparency simplifies debugging and makes the output directly usable for downstream applications like interactive editing or tool use.

## Performance

- **Referring Expression Segmentation** results

| Methods | refCOCO val | refCOCO testA | refCOCO testB | refCOCO+ val | refCOCO+ testA | refCOCO+ testB | refCOCOg val | refCOCOg test | Avg. |
|---------|-------------|---------------|---------------|--------------|----------------|----------------|--------------|---------------|------|
| **Decoder-based Models** | | | | | | | | | |
| NEXT-Chat | 74.7 | 78.9 | 69.5 | 65.1 | 71.9 | 56.7 | 67.0 | 67.0 | 68.9 |
| LISA | 74.9 | 79.1 | 72.3 | 65.1 | 70.8 | 58.1 | 67.9 | 70.6 | 69.9 |
| PixelLM | 73.0 | 76.5 | 68.2 | 66.3 | 71.7 | 58.3 | 69.3 | 70.5 | 69.2 |
| AnyRef | 76.9 | 79.9 | 74.2 | 70.3 | 73.5 | 61.8 | 70.0 | 70.7 | 72.2 |
| GSVA | 77.2 | 78.9 | 73.5 | 65.9 | 69.6 | 59.8 | 72.7 | 73.3 | 71.4 |
| LaSagNA | 76.8 | 78.7 | 73.8 | 66.4 | 70.6 | 60.1 | 70.6 | 71.9 | 71.1 |
| Groundhog | 78.5 | 79.9 | 75.7 | 70.5 | 75.0 | 64.9 | 74.1 | 74.6 | 74.2 |
| Text4Seg (w/ SAM) | 79.2 | 81.7 | 75.6 | 72.8 | 77.9 | 66.5 | 74.0 | 75.3 | 75.4 |
| **Decoder-free Models** | | | | | | | | | |
| Text4Seg | 74.7 | 77.4 | 71.6 | 68.5 | 73.6 | 62.9 | 70.7 | 71.6 | 71.4 |
| **SimpleSeg**-Qwen2.5-VL | 80.9 | 77.8 | 75.2 | 72.4 | 77.3 | 66.1 | 73.3 | 74.1 | 74.6 |
| **SimpleSeg**-Kimi-VL | 80.0 | 80.6 | 76.2 | 70.4 | 76.2 | 67.1 | 72.8 | 74.7 | 74.8 |

- **Referring Expression Comprehension** results

| Methods | refCOCO val | refCOCO testA | refCOCO testB | refCOCO+ val | refCOCO+ testA | refCOCO+ testB | refCOCOg val | refCOCOg test | Avg. |
|---------|-------------|---------------|---------------|--------------|----------------|----------------|--------------|---------------|------|
| **Decoder-based Models** | | | | | | | | | |
| LISA | 85.4 | 88.8 | 82.6 | 74.2 | 79.5 | 68.4 | 79.3 | 80.4 | 79.8 |
| GSVA | 86.3 | 89.2 | 83.8 | 72.8 | 78.8 | 68.0 | 81.6 | 81.8 | 80.3 |
| NEXT-Chat | 85.5 | 90.0 | 77.9 | 77.2 | 84.5 | 68.0 | 80.1 | 79.8 | 80.4 |
| PixelLM | 89.8 | 92.2 | 86.4 | 83.2 | 87.0 | 78.9 | 84.6 | 86.0 | 86.0 |
| Text4Seg (w/ SAM) | 90.3 | 93.4 | 87.5 | 85.2 | 89.9 | 79.5 | 85.4 | 85.4 | 87.1 |
| **Decoder-free Models** | | | | | | | | | |
| Text4Seg | 88.3 | 91.4 | 85.8 | 83.5 | 88.2 | 77.9 | 82.4 | 82.5 | 85.0 |
| **SimpleSeg**-Qwen2.5-VL | 90.2 | 92.9 | 86.1 | 84.6 | 90.5 | 79.0 | 84.9 | 85.6 | 86.7 |
| **SimpleSeg**-Kimi-VL | 91.3 | 92.1 | 87.1 | 82.6 | 88.3 | 79.3 | 84.6 | 86.3 | 86.5 |

# Model Usage

## Inference with 🤗 Hugging Face Transformers

We recommend Python 3.10, torch>=2.1.0, and transformers==4.48.2 as the development environment.

```python
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Load the model and processor (both ship custom code, hence trust_remote_code=True).
model_path = "simpleseganonymous/SimpleSeg"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

# Ask for the polygon of the target object in plain text.
image_path = "./figures/octopus.png"
image = Image.open(image_path)
messages = [
    {"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": "Output the polygon coordinates of octopus in the image."}]}
]
text = processor.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
inputs = processor(images=image, text=text, return_tensors="pt", padding=True, truncation=True).to(model.device)

# Generate and decode only the newly generated tokens.
generated_ids = model.generate(**inputs, max_new_tokens=512)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
response = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print(response)
```

## Decode the polygons and masks from the response string

The model returns polygons as JSON-style lists of normalized `[x, y]` points. The snippet below extracts them with a regular expression, scales them to the image size, and merges them into a single binary mask with `pycocotools`.

```python
import re
import json

import numpy as np
import pycocotools.mask as mask_utils


class RegexPatterns:
    # Patterns for parsing structured outputs from the response text.
    BOXED_PATTERN = r'\\boxed\{([^}]*)\}'
    BLOCK_PATTERN = r'^```$\r?\n(.*?)\r?\n^```$'

    NON_NEGATIVE_FLOAT_PATTERN = (
        r'(?:[1-9]\d*\.\d+|0\.\d+|\d+)'
    )

    BBOX_PATTERN = rf'\[\s*({NON_NEGATIVE_FLOAT_PATTERN})\s*,\s*({NON_NEGATIVE_FLOAT_PATTERN})\s*,\s*({NON_NEGATIVE_FLOAT_PATTERN})\s*,\s*({NON_NEGATIVE_FLOAT_PATTERN})\s*\]'
    POINT_PATTERN = (
        rf'\[\s*({NON_NEGATIVE_FLOAT_PATTERN})\s*,\s*({NON_NEGATIVE_FLOAT_PATTERN})\s*\]'
    )

    POLYGON_PATTERN = rf'\[\s*{POINT_PATTERN}(?:\s*,\s*{POINT_PATTERN})*\s*\]'


# Extract every polygon (a JSON list of [x, y] points) from the response string.
polygon_matches = [
    m.group(0) for m in re.finditer(RegexPatterns.POLYGON_PATTERN, response, re.DOTALL)
]
pred_polygons = []
for polygon_match in polygon_matches:
    polygon = json.loads(polygon_match)
    pred_polygons.append(polygon)

# Rasterize each polygon to a binary mask and merge them into a single mask.
width, height = image.size  # PIL image loaded in the inference snippet above
pred_masks = []
for pred_polygon in pred_polygons:
    # Coordinates are normalized to [0, 1]; scale them to pixel space.
    pred_polygon = np.array(pred_polygon) * np.array([width, height])
    rle = mask_utils.frPyObjects(pred_polygon.reshape((1, -1)).tolist(), height, width)
    mask = mask_utils.decode(rle)
    mask = np.sum(mask, axis=2, keepdims=True)
    pred_masks.append(mask)
pred_mask = np.sum(pred_masks, axis=0)
pred_mask = pred_mask.sum(axis=2)
pred_mask = (pred_mask > 0).astype(np.uint8)
```
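
## Visualize the predicted mask

The decoded mask can be overlaid on the input image for quick inspection, in line with the interpretable-output point above. This is a minimal sketch assuming the `image` and `pred_mask` variables from the previous snippets; the color, opacity, and output path are arbitrary.

```python
import numpy as np
from PIL import Image

# Blend a semi-transparent red overlay wherever the predicted mask is set.
overlay = np.array(image.convert("RGB"), dtype=np.float32)
color = np.array([255.0, 0.0, 0.0])
alpha = 0.5
overlay[pred_mask > 0] = (1 - alpha) * overlay[pred_mask > 0] + alpha * color
Image.fromarray(overlay.astype(np.uint8)).save("overlay.png")
```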