| import torch |
|
|
|
|
| |
def precompute_freqs_cis(dim: int,
                         end: int,
                         theta: float = 10000.0) -> torch.Tensor:
    """Build the table of unit complex rotations used by rotary embeddings.

    Args:
        dim: Feature size the table is built for; ``dim // 2`` frequencies
            are produced (an odd trailing feature gets no frequency).
        end: Number of positions (rows) to generate, i.e. positions
            ``0 .. end - 1``.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose entry ``[t, k]``
        equals ``exp(1j * t * theta ** (-2k / dim))``.
    """
    half = dim // 2
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    inv_freq = 1.0 / (theta**exponents)
    positions = torch.arange(end, device=inv_freq.device)
    # One rotation angle per (position, frequency) pair.
    angles = torch.outer(positions, inv_freq).float()
    # Unit magnitude, so every entry is a pure rotation.
    return torch.polar(torch.ones_like(angles), angles)
|
|
|
|
| |
| |
def google_apply_rotary_emb(x: torch.Tensor,
                            freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding to the query and key tensors.

    Uses "half-split" pairing: the last dimension of ``x`` (size ``D``) is
    cut into two halves. Element ``i`` of the first half and element ``i`` of
    the second half form the real and imaginary parts of one complex number,
    and that number is multiplied by ``freqs_cis``. The rotated values are
    written back in the same half-split layout.

    Args:
        x: Real tensor whose last dimension ``D`` is even. Any number of
            leading dimensions is accepted.
        freqs_cis: Complex tensor that must broadcast against
            ``x.shape[:-1] + (D // 2,)``.

    Returns:
        Tensor with the dtype of ``x``. Its shape equals ``x.shape`` unless
        ``freqs_cis`` broadcasts to extra leading dims.
    """
    # (..., D) -> two (..., D/2) halves -> (..., D/2, 2) -> complex (..., D/2)
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    # (..., D/2, 2) -> (..., D, 1): all real parts first, then all imaginary.
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    # Drop the trailing singleton. Rank-agnostic: the previous 4-index
    # reshape returned a (..., D, 1) tensor for non-4-D inputs.
    return x_out.reshape(*x_out.shape[:-2], -1)
|
|
|
|
def llama_apply_rotary_emb(x: torch.Tensor,
                           freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding using interleaved pairing.

    Adjacent elements ``(x[..., 2k], x[..., 2k + 1])`` of the last dimension
    form the real and imaginary parts of one complex number. That number is
    multiplied by ``freqs_cis`` and written back to the same interleaved
    positions.

    Args:
        x: Real tensor whose last dimension ``D`` is even. Any number of
            leading dimensions is accepted.
        freqs_cis: Complex tensor that must broadcast against
            ``x.shape[:-1] + (D // 2,)``.

    Returns:
        Tensor with the dtype of ``x``. Its shape equals ``x.shape`` unless
        ``freqs_cis`` broadcasts to extra leading dims.
    """
    # (..., D) -> (..., D/2, 2) -> complex (..., D/2)
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Merge the trailing (D/2, 2) back into D. flatten(-2) works for any rank;
    # the previous flatten(3) only did so for 4-D inputs.
    x_out = torch.view_as_real(x_ * freqs_cis).flatten(-2)
    return x_out.type_as(x)
|
|
|
|
# Registry of rotary-embedding styles, keyed by style name:
#   'google': half-split pairing (first half = real, second half = imaginary).
#   'llama':  interleaved pairing (adjacent elements form one complex value).
WENET_APPLY_ROTARY_EMB = dict([
    ('google', google_apply_rotary_emb),
    ('llama', llama_apply_rotary_emb),
])
|
|