torch.compile падает на Tesla T4
#17
by
jor-bored
- opened
Падает, потому что она SM75, а для компиляции bf16 нужно >=SM80.
Но можно решить проще. В коде модели строго все приводится к bf16.
Строка 991: with torch.autocast('cuda', dtype=torch.bfloat16):
Просьба пофиксить, чтобы все красиво работало на T4.
У меня пока работает некрасиво через быструю затычку)))
from transformers.modeling_outputs import BaseModelOutputWithPast
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
return_embeddings: bool = False, **kwargs):
kwargs.pop('token_type_ids', None)
with torch.autocast('cuda', dtype=torch.float16):
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
last_hidden = self.latent_attention_model(outputs.last_hidden_state, attention_mask)
if return_embeddings:
return self.mean_pool(last_hidden, attention_mask)
return BaseModelOutputWithPast(last_hidden_state=last_hidden)
model = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, trust_remote_code=True)
model.forward = forward.get(model, model.class)