Upload folder using huggingface_hub
Browse files- geopixel.py +13 -6
geopixel.py
CHANGED
|
@@ -12,6 +12,7 @@ from model.IXC.modeling_internlm_xcomposer2 import InternLMXComposer2ForCausalLM
|
|
| 12 |
from model.IXC.modeling_internlm2 import InternLM2Model
|
| 13 |
from model.sam2.build_sam import build_sam2_hf
|
| 14 |
from model.sam2.utils.transforms import SAM2Transforms
|
|
|
|
| 15 |
try:
|
| 16 |
from transformers.generation.streamers import BaseStreamer
|
| 17 |
except: # noqa # pylint: disable=bare-except
|
|
@@ -93,8 +94,10 @@ class GeoPixelMetaModel:
|
|
| 93 |
(128, 128),
|
| 94 |
(64, 64),
|
| 95 |
]
|
|
|
|
| 96 |
for param in self.visual_model.parameters():
|
| 97 |
param.requires_grad = False
|
|
|
|
| 98 |
if config.train_mask_decoder:
|
| 99 |
self.visual_model.sam_mask_decoder.train()
|
| 100 |
for param in self.visual_model.sam_mask_decoder.parameters():
|
|
@@ -195,6 +198,8 @@ class GeoPixelForCausalLM(InternLMXComposer2ForCausalLM):
|
|
| 195 |
samples = kwargs.get('samples', None)
|
| 196 |
if samples and samples['data_type'][0] == 'grounding':
|
| 197 |
kwargs['output_hidden_states'] = True
|
|
|
|
|
|
|
| 198 |
torch.cuda.empty_cache()
|
| 199 |
outputs = super().forward(**kwargs)
|
| 200 |
|
|
@@ -246,9 +251,6 @@ class GeoPixelForCausalLM(InternLMXComposer2ForCausalLM):
|
|
| 246 |
low_res_masks,
|
| 247 |
ori_hw[i],
|
| 248 |
)
|
| 249 |
-
|
| 250 |
-
# pred_masks = pred_masks.squeeze(0)
|
| 251 |
-
# all_pred_masks.append(pred_masks)
|
| 252 |
all_pred_masks.append(pred_masks[:, 0])
|
| 253 |
|
| 254 |
|
|
@@ -320,27 +322,32 @@ class GeoPixelForCausalLM(InternLMXComposer2ForCausalLM):
|
|
| 320 |
hd_num: int = 9,
|
| 321 |
history: List[Tuple[str, str]] = [],
|
| 322 |
max_new_tokens: int = 1024,
|
|
|
|
| 323 |
**kwargs,
|
| 324 |
):
|
| 325 |
with torch.no_grad():
|
| 326 |
inputs, im_mask, _ = self.interleav_wrap_chat(query, images, history=history, hd_num=hd_num)
|
| 327 |
-
print(im_mask.sum().item())
|
| 328 |
inputs = {
|
| 329 |
k: v.to(self.device)
|
| 330 |
for k, v in inputs.items() if torch.is_tensor(v)
|
| 331 |
}
|
| 332 |
-
# print(len(inputs['inputs_embeds'][0]))
|
| 333 |
eos_token_id = [
|
| 334 |
tokenizer.eos_token_id,
|
| 335 |
#tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
|
| 336 |
]
|
| 337 |
all_pred_masks = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
outputs = self.generate(
|
| 339 |
**inputs,
|
| 340 |
max_new_tokens=max_new_tokens,
|
| 341 |
im_mask=im_mask,
|
| 342 |
input_ids = None,
|
| 343 |
-
streamer=
|
| 344 |
num_beams=1,
|
| 345 |
do_sample=False,
|
| 346 |
temperature=1.0,
|
|
|
|
| 12 |
from model.IXC.modeling_internlm2 import InternLM2Model
|
| 13 |
from model.sam2.build_sam import build_sam2_hf
|
| 14 |
from model.sam2.utils.transforms import SAM2Transforms
|
| 15 |
+
from transformers import TextStreamer
|
| 16 |
try:
|
| 17 |
from transformers.generation.streamers import BaseStreamer
|
| 18 |
except: # noqa # pylint: disable=bare-except
|
|
|
|
| 94 |
(128, 128),
|
| 95 |
(64, 64),
|
| 96 |
]
|
| 97 |
+
|
| 98 |
for param in self.visual_model.parameters():
|
| 99 |
param.requires_grad = False
|
| 100 |
+
|
| 101 |
if config.train_mask_decoder:
|
| 102 |
self.visual_model.sam_mask_decoder.train()
|
| 103 |
for param in self.visual_model.sam_mask_decoder.parameters():
|
|
|
|
| 198 |
samples = kwargs.get('samples', None)
|
| 199 |
if samples and samples['data_type'][0] == 'grounding':
|
| 200 |
kwargs['output_hidden_states'] = True
|
| 201 |
+
kwargs['use_cache'] = False
|
| 202 |
+
|
| 203 |
torch.cuda.empty_cache()
|
| 204 |
outputs = super().forward(**kwargs)
|
| 205 |
|
|
|
|
| 251 |
low_res_masks,
|
| 252 |
ori_hw[i],
|
| 253 |
)
|
|
|
|
|
|
|
|
|
|
| 254 |
all_pred_masks.append(pred_masks[:, 0])
|
| 255 |
|
| 256 |
|
|
|
|
| 322 |
hd_num: int = 9,
|
| 323 |
history: List[Tuple[str, str]] = [],
|
| 324 |
max_new_tokens: int = 1024,
|
| 325 |
+
stream: bool = False,
|
| 326 |
**kwargs,
|
| 327 |
):
|
| 328 |
with torch.no_grad():
|
| 329 |
inputs, im_mask, _ = self.interleav_wrap_chat(query, images, history=history, hd_num=hd_num)
|
|
|
|
| 330 |
inputs = {
|
| 331 |
k: v.to(self.device)
|
| 332 |
for k, v in inputs.items() if torch.is_tensor(v)
|
| 333 |
}
|
|
|
|
| 334 |
eos_token_id = [
|
| 335 |
tokenizer.eos_token_id,
|
| 336 |
#tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
|
| 337 |
]
|
| 338 |
all_pred_masks = []
|
| 339 |
+
|
| 340 |
+
if stream:
|
| 341 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 342 |
+
else:
|
| 343 |
+
streamer = None
|
| 344 |
+
|
| 345 |
outputs = self.generate(
|
| 346 |
**inputs,
|
| 347 |
max_new_tokens=max_new_tokens,
|
| 348 |
im_mask=im_mask,
|
| 349 |
input_ids = None,
|
| 350 |
+
streamer= streamer,
|
| 351 |
num_beams=1,
|
| 352 |
do_sample=False,
|
| 353 |
temperature=1.0,
|