File size: 4,976 Bytes
67d6ead
 
 
 
384cd73
67d6ead
 
 
 
 
 
 
 
 
 
 
 
 
384cd73
67d6ead
384cd73
67d6ead
 
 
 
 
 
 
 
 
384cd73
 
 
 
67d6ead
 
 
 
 
 
 
384cd73
67d6ead
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384cd73
67d6ead
 
 
 
 
 
 
 
 
 
 
 
 
384cd73
 
 
 
67d6ead
384cd73
67d6ead
 
 
 
 
 
384cd73
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import torch.nn as nn
from transformers import GenerationMixin, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast  # standard HF return type for causal LMs

from .configuration_tinygpt import TinyGPTConfig

class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-feature gain.

    Every vector along the last axis is divided by the square root of its
    mean squared value (plus ``eps``) and then multiplied element-wise by
    ``weight``. There is no mean subtraction and no bias term.

    Args:
        dim: Size of the last axis of the input; also the length of ``weight``.
        eps: Constant added under the square root for numerical stability.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Gain starts at 1, so at initialisation the module only normalises.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = (mean_square + self.eps).rsqrt()
        return x * inv_rms * self.weight

class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    Expands from ``config.hidden_size`` to ``config.intermediate_size`` and
    projects back; both linear layers carry a bias.
    """

    def __init__(self, config):
        super().__init__()
        d_model, d_ff = config.hidden_size, config.intermediate_size
        self.fc_in = nn.Linear(d_model, d_ff, bias=True)
        self.act = nn.GELU()
        self.fc_out = nn.Linear(d_ff, d_model, bias=True)

    def forward(self, x):
        hidden = self.fc_in(x)
        hidden = self.act(hidden)
        return self.fc_out(hidden)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    ``config.hidden_size`` is divided into ``config.num_attention_heads``
    heads of size ``hidden_size // num_attention_heads``. The query, key,
    value and output projections are ``hidden_size -> hidden_size`` linear
    layers with bias.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.n_heads = config.num_attention_heads
        self.head_dim = width // self.n_heads
        # 1/sqrt(head_dim): standard scaled dot-product attention factor.
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(width, width, bias=True)
        self.k_proj = nn.Linear(width, width, bias=True)
        self.v_proj = nn.Linear(width, width, bias=True)
        self.out_proj = nn.Linear(width, width, bias=True)

    def _split_heads(self, t, batch, seq):
        # (batch, seq, width) -> (batch, n_heads, seq, head_dim)
        return t.view(batch, seq, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        """Attend over the sequence axis of ``x``.

        Args:
            x: Tensor unpacked as (batch, seq, width), width == hidden_size.
            mask: Optional tensor broadcastable to the (batch, heads, seq, seq)
                score matrix. Entries equal to 0 are filled with -inf before
                the softmax. A 2-D mask gets two leading singleton axes
                prepended, so it is shared across batch and heads.

        Returns:
            Tensor of shape (batch, seq, width).
        """
        batch, seq, width = x.shape
        q = self._split_heads(self.q_proj(x), batch, seq)
        k = self._split_heads(self.k_proj(x), batch, seq)
        v = self._split_heads(self.v_proj(x), batch, seq)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                mask = mask[None, None]
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = scores.softmax(dim=-1)

        # Merge heads back: (batch, n_heads, seq, head_dim) -> (batch, seq, width)
        context = torch.matmul(weights, v).transpose(1, 2).contiguous()
        return self.out_proj(context.view(batch, seq, width))

class Block(nn.Module):
    """Pre-norm transformer layer: an attention sub-layer, then an MLP sub-layer.

    Each sub-layer reads an RMS-normalised copy of the residual stream. Its
    output is added back onto the un-normalised stream.
    """

    def __init__(self, config):
        super().__init__()
        eps = config.rms_norm_eps
        self.norm_1 = RMSNorm(config.hidden_size, eps)
        self.attn = Attention(config)
        self.norm_2 = RMSNorm(config.hidden_size, eps)
        self.mlp = MLP(config)

    def forward(self, x, mask=None):
        attn_out = self.attn(self.norm_1(x), mask)
        x = x + attn_out
        mlp_out = self.mlp(self.norm_2(x))
        return x + mlp_out

class TinyGPTPreTrainedModel(PreTrainedModel):
    """Common base class for the TinyGPT models in this module.

    Binds ``TinyGPTConfig`` as the configuration type. Declares
    ``"transformer"`` as the attribute name under which the head model stores
    its backbone. Defines the per-module weight-initialisation hook.
    """

    config_class = TinyGPTConfig
    base_model_prefix = "transformer"

    def _init_weights(self, module):
        """Initialise a single sub-module in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal
        distribution with std 0.02. Linear biases, when present, are zeroed.
        Any other module type is left unchanged.
        """
        std = 0.02
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, std=std)
            return
        if not isinstance(module, nn.Linear):
            return
        torch.nn.init.normal_(module.weight, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

class TinyGPTModel(TinyGPTPreTrainedModel):
    """TinyGPT backbone without an LM head.

    It has three stages:

    - token embeddings plus learned absolute position embeddings;
    - a stack of ``Block`` layers run under a causal mask;
    - a final RMSNorm.
    """

    def __init__(self, config):
        super().__init__(config)
        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.h = nn.ModuleList([Block(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = RMSNorm(config.hidden_size, config.rms_norm_eps)
        # Hugging Face convention: post_init() runs ``_init_weights`` over the
        # newly created modules. Without it the std=0.02 init is never applied.
        self.post_init()

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids`` into final hidden states.

        Args:
            input_ids: Integer tensor of shape (batch, seq).
            attention_mask: Optional (batch, seq) tensor. Positions whose value
                is 0 are treated as padding and are never attended to as keys.
                NOTE(review): position ids are always 0..seq-1, so with left
                padding the real tokens receive shifted positions.

        Returns:
            Tensor of shape (batch, seq, hidden_size).

        Raises:
            ValueError: if ``seq`` exceeds the number of position embeddings.
        """
        batch, seq_len = input_ids.shape
        max_positions = self.wpe.num_embeddings
        if seq_len > max_positions:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings "
                f"({max_positions})."
            )
        device = input_ids.device

        pos = torch.arange(0, seq_len, dtype=torch.long, device=device)
        x = self.wte(input_ids) + self.wpe(pos)

        # Causal mask: query i may attend to keys 0..i. Shape (1, 1, T, T).
        mask = torch.tril(
            torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
        ).view(1, 1, seq_len, seq_len)

        if attention_mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq): hide padded *keys* from every query.
            key_valid = attention_mask.to(device=device).bool()[:, None, None, :]
            # Always keep the diagonal so every query has at least one valid
            # key. A fully masked row would softmax to NaN. Because
            # 0 * NaN = NaN in the attention-weighted sum over V, that NaN
            # would then spread to the real positions in later layers.
            diagonal = torch.eye(seq_len, dtype=torch.bool, device=device).view(
                1, 1, seq_len, seq_len
            )
            mask = mask & (key_valid | diagonal)

        for layer in self.h:
            x = layer(x, mask)
        return self.ln_f(x)

class TinyGPTForCausalLM(TinyGPTPreTrainedModel, GenerationMixin):
    """TinyGPT backbone plus a bias-free linear LM head for next-token prediction.

    ``GenerationMixin`` is inherited explicitly, after the ``PreTrainedModel``
    subclass, as the ``transformers`` docs require. Starting with v4.50 the
    library no longer mixes it into ``PreTrainedModel``, and without it
    ``generate()`` is unavailable.
    """

    def __init__(self, config):
        super().__init__(config)
        self.transformer = TinyGPTModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Hugging Face convention: post_init() runs ``_init_weights`` over the
        # newly created modules (including ``lm_head``).
        self.post_init()

    # ``**kwargs`` absorbs the extra arguments Hugging Face callers pass
    # (e.g. ``return_dict``, ``output_attentions``); this model does not use them.
    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute next-token logits and, when ``labels`` are given, the LM loss.

        Args:
            input_ids: Integer tensor of shape (batch, seq).
            attention_mask: Optional mask forwarded unchanged to the backbone.
            labels: Optional integer tensor aligned with ``input_ids``. The
                loss compares the logits at position t with the label at t+1.
                It uses ``nn.CrossEntropyLoss`` defaults: mean reduction, and
                labels equal to -100 are ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape
            (batch, seq, vocab_size) and ``loss`` (None when no labels are given).
        """
        hidden = self.transformer(input_ids, attention_mask)
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # Labels may live on another device (e.g. model split across GPUs).
            labels = labels.to(logits.device)
            # Shift so that the logits at t are scored against the token at t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # Return the standard Hugging Face output object, for compatibility
        # with Trainer and generate().
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Build the ``forward`` keyword arguments for one generation step.

        No KV cache is implemented, so the full ``input_ids`` prefix is fed
        again at every step. ``attention_mask`` is forwarded so that padding
        can be honoured by the backbone.
        """
        return {"input_ids": input_ids, "attention_mask": attention_mask}