from model_components import Block
from constants import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import tokenizer, vocab_size

class DecoderLanguageModel(nn.Module):
    """
    Transformer Decoder Language Model with optional coordinate regression head.
    Processes a combined sequence of embeddings.
    Outputs logits for token prediction and optionally regressed coordinates (for MAX_POINTS).
    """
    def __init__(self, n_embd=HIDDEN_DIM, vocab_size=vocab_size, num_heads=NUM_HEADS,
                 n_layer=NUM_LAYERS, max_context=CONTEXT_LENGTH, dropout=DROPOUT):
        super().__init__()
        # --- Input Embeddings ---
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(max_context, n_embd)
        self.dropout = nn.Dropout(dropout)

        # --- Transformer Blocks ---
        self.blocks = nn.ModuleList([
            Block(n_embd, num_heads, dropout, is_decoder=True)
            for _ in range(n_layer)
        ])

        # --- Final Layer Norm ---
        self.ln_f = nn.LayerNorm(n_embd)

        # --- Output Heads ---
        # 1. Head for token classification
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        # 2. Head for direct coordinate regression (predicting MAX_POINTS * 2 values)
        self.regression_head = nn.Sequential(
            nn.Linear(n_embd, n_embd // 2),
            nn.GELU(),
            nn.Linear(n_embd // 2, MAX_POINTS * 2), # Output MAX_POINTS * (x, y)
            nn.Sigmoid()                           # Output activation [0, 1]
        )
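        # NOTE: this head is not called inside forward(); the VLM applies it
        # externally to the normalized hidden states (x_norm) returned below.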
        # --- End Output Heads ---

        self.n_embd = n_embd
        self.max_context = max_context
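        # Weight tying: the input embedding and the LM head share one weight matrix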
        self.token_embedding_table.weight = self.lm_head.weight
        self.apply(self._init_weights)
        print(f"DecoderLanguageModel initialized with {n_layer} layers.")

    def _init_weights(self, module):
        # GPT-style initialization: N(0, 0.02) for Linear/Embedding weights, zeros for biases
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)


    def forward(self, combined_embeds, attention_mask=None, targets=None):
        """
        Forward pass for training or inference where loss is calculated.
        Regression output is now handled *outside* this module by VLM.
        """
        # --- Input Validation & Processing ---
        if combined_embeds.ndim != 3:
            raise ValueError(f"DecoderLM received non-3D combined_embeds! Shape: {combined_embeds.shape}")
        B, T, C = combined_embeds.shape
        if T > self.max_context:
            # Keep only the most recent max_context positions
            print(f"WARNING (Decoder forward): Input sequence length {T} > max context {self.max_context}. Truncating.")
            combined_embeds = combined_embeds[:, -self.max_context:, :]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -self.max_context:]
            if targets is not None:
                targets = targets[:, -self.max_context:]
            T = self.max_context

        # --- Positional Encoding ---
        pos = torch.arange(0, T, dtype=torch.long, device=combined_embeds.device)
        pos = pos.clamp(max=self.position_embedding_table.num_embeddings - 1)
        pos_emb = self.position_embedding_table(pos) # Shape: (T, C)
        x = combined_embeds + pos_emb.unsqueeze(0)
        x = self.dropout(x)

        # --- Transformer Blocks ---
        for block in self.blocks:
            x = block(x, attention_mask=attention_mask)

        # --- Final Layer Norm ---
        x_norm = self.ln_f(x) # Shape: (B, T, C) - Pass this out for VLM regression head

        # --- Classification Head Output ---
        logits = self.lm_head(x_norm) # Shape: (B, T, VocabSize)

        # --- Classification Loss Calculation ---
        class_loss = None
        if targets is not None:
            try:
                class_loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    targets.view(-1),
                    ignore_index=-100  # positions labeled -100 (padding/prompt) are excluded
                )
                if torch.isnan(class_loss):
                    print("Warning: class_loss is NaN.")
                    class_loss = None
            except Exception as e:
                print(f"Error calculating cross_entropy: {e}")
                print(f"Logits shape: {logits.shape}, Targets shape: {targets.shape}")
                class_loss = None

        # Return logits, class_loss, and the final normalized hidden states
        return logits, class_loss, x_norm

    # --- Generation Method (example; only needed if generation happens here rather than in the VLM) ---
    # Used if the VLM needs this class to generate directly from token IDs:
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Autoregressive generation based on starting token IDs.
        NOTE: This version doesn't handle combined embeddings directly.
              The VisionLanguageModel should ideally use a method like
              generate_from_embeddings or implement the loop externally.
        """
        was_training = self.training  # remember mode so it can be restored afterwards
        self.eval()
        for _ in range(max_new_tokens):
            # --- Context Management ---
            # Crop idx if longer than context length
            idx_cond = idx if idx.size(1) <= self.max_context else idx[:, -self.max_context:]

            # --- Forward Pass ---
            # Get embeddings
            tok_embeds = self.token_embedding_table(idx_cond) # (B, T, C)
            # Get positional embeddings
            pos = torch.arange(0, idx_cond.size(1), dtype=torch.long, device=idx.device)
            pos = pos.clamp(max=self.max_context - 1)
            pos_emb = self.position_embedding_table(pos).unsqueeze(0) # (1, T, C)
            x = self.dropout(tok_embeds + pos_emb)
            # Pass through blocks (no padding mask needed: generated sequences are unpadded)
            for block in self.blocks:
                x = block(x, attention_mask=None) # Causal mask is internal to block/head
            # Final layer norm and head for the last token only
            x = self.ln_f(x[:, -1:, :]) # (B, 1, C)
            logits = self.lm_head(x)    # (B, 1, V)
            logits = logits.squeeze(1) # (B, V)

            # --- Sampling ---
            logits = logits / temperature
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)

            # Append sampled token
            idx = torch.cat((idx, idx_next), dim=1)

            # Stop only once every sequence in the batch has just emitted EOS
            if hasattr(tokenizer, 'eos_token_id') and (idx_next == tokenizer.eos_token_id).all():
                break
        if was_training:
            self.train()
        return idx
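
    # --- Sketch: generation from combined embeddings ---
    # The docstring above suggests a generate_from_embeddings method for the VLM
    # use case. The following is an illustrative sketch, not part of the original
    # module: it reuses the same blocks as forward(), starts from a prefix of
    # combined embeddings (e.g. vision + prompt), and feeds each sampled token
    # back in as a token embedding.
    @torch.no_grad()
    def generate_from_embeddings(self, prefix_embeds, max_new_tokens, temperature=1.0, top_k=None):
        was_training = self.training
        self.eval()
        embeds = prefix_embeds  # (B, T0, C)
        generated = torch.empty(prefix_embeds.size(0), 0, dtype=torch.long,
                                device=prefix_embeds.device)
        for _ in range(max_new_tokens):
            x = embeds[:, -self.max_context:, :]  # crop to the context window
            pos = torch.arange(0, x.size(1), dtype=torch.long, device=x.device)
            x = self.dropout(x + self.position_embedding_table(pos).unsqueeze(0))
            for block in self.blocks:
                x = block(x, attention_mask=None)
            logits = self.lm_head(self.ln_f(x[:, -1, :]))  # (B, V)
            logits = logits / temperature
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            generated = torch.cat((generated, idx_next), dim=1)
            # Feed the sampled token back in as an embedding for the next step
            embeds = torch.cat((embeds, self.token_embedding_table(idx_next)), dim=1)
            if hasattr(tokenizer, 'eos_token_id') and (idx_next == tokenizer.eos_token_id).all():
                break
        if was_training:
            self.train()
        return generated


# --- Usage sketch (illustrative only) ---
# A minimal smoke test, assuming constants.py defines HIDDEN_DIM and MAX_POINTS
# and utils provides vocab_size, as imported above. Shapes follow forward()'s
# contract: combined_embeds is (B, T, HIDDEN_DIM); targets is (B, T) with -100
# at positions to ignore.
if __name__ == "__main__":
    model = DecoderLanguageModel()
    B, T = 2, 16
    dummy_embeds = torch.randn(B, T, HIDDEN_DIM)
    dummy_targets = torch.randint(0, vocab_size, (B, T))
    logits, class_loss, x_norm = model(dummy_embeds, targets=dummy_targets)
    print(f"logits: {tuple(logits.shape)}, x_norm: {tuple(x_norm.shape)}, loss: {class_loss}")

    # Coordinate regression as performed externally by the VLM (see forward()
    # docstring): apply the head to the last hidden state of each sequence.
    coords = model.regression_head(x_norm[:, -1, :])  # (B, MAX_POINTS * 2), values in [0, 1]
    print(f"coords: {tuple(coords.shape)}")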