torch.compile падает на Tesla T4

#17
by jor-bored - opened

Падает, потому что она SM75, а для компиляции bf16 нужно >=SM80.
Но можно решить проще. В коде модели строго все приводится к bf16.
Строка 991: with torch.autocast('cuda', dtype=torch.bfloat16):
Просьба пофиксить, чтобы все красиво работало на T4.

У меня пока работает некрасиво через быструю затычку)))
from transformers.modeling_outputs import BaseModelOutputWithPast

def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
return_embeddings: bool = False, **kwargs):
kwargs.pop('token_type_ids', None)

with torch.autocast('cuda', dtype=torch.float16):
    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

    last_hidden = self.latent_attention_model(outputs.last_hidden_state, attention_mask)

if return_embeddings:
    return self.mean_pool(last_hidden, attention_mask)

return BaseModelOutputWithPast(last_hidden_state=last_hidden)

model = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, trust_remote_code=True)
model.forward = forward.get(model, model.class)

Sign up or log in to comment