simpleseganonymous commited on
Commit
97af7df
·
verified ·
1 Parent(s): d0a3579

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +79 -1
README.md CHANGED
@@ -60,4 +60,82 @@ Without introducing any complex architectures or special patterns, we show how e
60
  | Text4Seg (w/ SAM)| 90.3 | 93.4 | 87.5 | 85.2 | 89.9 | 79.5 | 85.4 | 85.4 | 87.1 |
61
  | **Decoder-free Models** | | | | | | | | | |
62
  | Text4Seg | 88.3 | 91.4 | 85.8 | 83.5 | 88.2 | 77.9 | 82.4 | 82.5 | 85.0 |
63
- | **SimpleSeg** | 90.5 | 92.9 | 86.8 | 85.3 | 89.5 | 80.2 | 86.1 | 86.5 | 87.2 |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  | Text4Seg (w/ SAM)| 90.3 | 93.4 | 87.5 | 85.2 | 89.9 | 79.5 | 85.4 | 85.4 | 87.1 |
61
  | **Decoder-free Models** | | | | | | | | | |
62
  | Text4Seg | 88.3 | 91.4 | 85.8 | 83.5 | 88.2 | 77.9 | 82.4 | 82.5 | 85.0 |
63
+ | **SimpleSeg** | 90.5 | 92.9 | 86.8 | 85.3 | 89.5 | 80.2 | 86.1 | 86.5 | 87.2 |
64
+
65
+
66
+ # Model Usage
67
+
68
+
69
+ ## Inference with 🤗 Hugging Face Transformers
70
+
71
+ It is recommended to use python=3.10, torch>=2.1.0, and transformers=4.48.2 as the development environment.
72
+
73
+ ```python
74
+ from PIL import Image
75
+ from transformers import AutoModelForCausalLM, AutoProcessor
76
+ model_path = "simpleseganonymous/SimpleSeg"
77
+ model = AutoModelForCausalLM.from_pretrained(
78
+ model_path,
79
+ torch_dtype="auto",
80
+ device_map="auto",
81
+ trust_remote_code=True,
82
+ )
83
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
84
+ image_path = "./figures/octopus.png"
85
+ image = Image.open(image_path)
86
+ messages = [
87
+ {"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": "Output the polygon coordinates of octopus in the image."}]}
88
+ ]
89
+ text = processor.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
90
+ inputs = processor(images=image, text=text, return_tensors="pt", padding=True, truncation=True).to(model.device)
91
+ generated_ids = model.generate(**inputs, max_new_tokens=512)
92
+ generated_ids_trimmed = [
93
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
94
+ ]
95
+ response = processor.batch_decode(
96
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
97
+ )[0]
98
+ print(response)
99
+
100
+ ```
101
+
102
+ ## Decode the polygons and masks from the response string
103
+
104
+ ```python
105
+ import re
106
+ import pycocotools.mask as mask_utils
107
+
108
+ class RegexPatterns:
109
+ BOXED_PATTERN = r'\\boxed\{([^}]*)\}'
110
+ BLOCK_PATTERN = r'^```$\r?\n(.*?)\r?\n^```$'
111
+
112
+ NON_NEGATIVE_FLOAT_PATTERN = (
113
+ r'(?:[1-9]\d*\.\d+|0\.\d+|\d+)'
114
+ )
115
+
116
+ BBOX_PATTERN = rf'\[\s*({NON_NEGATIVE_FLOAT_PATTERN})\s*,\s*({NON_NEGATIVE_FLOAT_PATTERN})\s*,\s*({NON_NEGATIVE_FLOAT_PATTERN})\s*,\s*({NON_NEGATIVE_FLOAT_PATTERN})\s*\]'
117
+ POINT_PATTERN = (
118
+ rf'\[\s*({NON_NEGATIVE_FLOAT_PATTERN})\s*,\s*({NON_NEGATIVE_FLOAT_PATTERN})\s*\]'
119
+ )
120
+
121
+ POLYGON_PATTERN = rf'\[\s*{POINT_PATTERN}(?:\s*,\s*{POINT_PATTERN})*\s*\]'
122
+
123
+ polygon_matches = [
124
+ m.group(0) for m in re.finditer(RegexPatterns.POLYGON_PATTERN, response, re.DOTALL)
125
+ ]
126
+ pred_polygons = []
127
+ for polygon_match in polygon_matches:
128
+ polygon = json.loads(polygon_match)
129
+ pred_polygons.append(polygon)
130
+
131
+ pred_masks = []
132
+ for pred_polygon in pred_polygons:
133
+ pred_polygon = np.array(pred_polygon) * np.array([width, height])
134
+ rle = mask_utils.frPyObjects(pred_polygon.reshape((1, -1)).tolist(), height, width)
135
+ mask = mask_utils.decode(rle)
136
+ mask = np.sum(mask, axis=2, keepdims=True)
137
+ pred_masks.append(mask)
138
+ pred_mask = np.sum(pred_masks, axis=0)
139
+ pred_mask = pred_mask.sum(axis=2)
140
+ pred_mask = (pred_mask > 0).astype(np.uint8)
141
+ ```