Faaz commited on
Commit ·
cdc806e
1
Parent(s): 02eef51
Fix: register LLM as nn.Module submodule so optimizer finds LoRA params
Browse files
src/model/mindi_model.py
CHANGED
|
@@ -111,6 +111,9 @@ class MINDI15(nn.Module):
|
|
| 111 |
# Apply LoRA
|
| 112 |
self.architecture.apply_lora()
|
| 113 |
|
|
|
|
|
|
|
|
|
|
| 114 |
# 3. Vision encoder (frozen CLIP + trainable projection)
|
| 115 |
self.vision_encoder = VisionEncoder(
|
| 116 |
model_name=clip_model,
|
|
@@ -183,10 +186,8 @@ class MINDI15(nn.Module):
|
|
| 183 |
Returns:
|
| 184 |
Dict with 'loss', 'logits', and optionally 'visual_tokens'.
|
| 185 |
"""
|
| 186 |
-
model = self.architecture.get_model()
|
| 187 |
-
|
| 188 |
# Get text embeddings from the LLM's embedding layer
|
| 189 |
-
text_embeds =
|
| 190 |
|
| 191 |
# Encode vision if image provided
|
| 192 |
visual_tokens = None
|
|
@@ -209,7 +210,7 @@ class MINDI15(nn.Module):
|
|
| 209 |
labels = torch.cat([visual_labels, labels], dim=1)
|
| 210 |
|
| 211 |
# Forward through LLM with embeddings (bypass tokenization)
|
| 212 |
-
outputs =
|
| 213 |
inputs_embeds=fused_embeds,
|
| 214 |
attention_mask=fused_mask,
|
| 215 |
labels=labels,
|
|
@@ -257,8 +258,7 @@ class MINDI15(nn.Module):
|
|
| 257 |
Returns:
|
| 258 |
Generated text string (decoded with MINDI tokenizer).
|
| 259 |
"""
|
| 260 |
-
|
| 261 |
-
model.eval()
|
| 262 |
|
| 263 |
# Tokenize with MINDI tokenizer
|
| 264 |
inputs = self.tokenizer(prompt, return_tensors="pt")
|
|
@@ -267,11 +267,11 @@ class MINDI15(nn.Module):
|
|
| 267 |
|
| 268 |
# If image provided, build fused embeddings
|
| 269 |
if image is not None:
|
| 270 |
-
text_embeds =
|
| 271 |
visual_tokens = self.vision_encoder.encode_image(image)
|
| 272 |
fused_embeds, fused_mask = self.fusion(text_embeds, visual_tokens, attention_mask)
|
| 273 |
|
| 274 |
-
output_ids =
|
| 275 |
inputs_embeds=fused_embeds,
|
| 276 |
attention_mask=fused_mask,
|
| 277 |
max_new_tokens=max_new_tokens,
|
|
@@ -284,7 +284,7 @@ class MINDI15(nn.Module):
|
|
| 284 |
)
|
| 285 |
else:
|
| 286 |
# Text-only generation (direct input_ids)
|
| 287 |
-
output_ids =
|
| 288 |
input_ids=input_ids,
|
| 289 |
attention_mask=attention_mask,
|
| 290 |
max_new_tokens=max_new_tokens,
|
|
|
|
| 111 |
# Apply LoRA
|
| 112 |
self.architecture.apply_lora()
|
| 113 |
|
| 114 |
+
# Register the LLM as a submodule so .parameters() finds it
|
| 115 |
+
self.llm = self.architecture.get_model()
|
| 116 |
+
|
| 117 |
# 3. Vision encoder (frozen CLIP + trainable projection)
|
| 118 |
self.vision_encoder = VisionEncoder(
|
| 119 |
model_name=clip_model,
|
|
|
|
| 186 |
Returns:
|
| 187 |
Dict with 'loss', 'logits', and optionally 'visual_tokens'.
|
| 188 |
"""
|
|
|
|
|
|
|
| 189 |
# Get text embeddings from the LLM's embedding layer
|
| 190 |
+
text_embeds = self.llm.get_input_embeddings()(input_ids)
|
| 191 |
|
| 192 |
# Encode vision if image provided
|
| 193 |
visual_tokens = None
|
|
|
|
| 210 |
labels = torch.cat([visual_labels, labels], dim=1)
|
| 211 |
|
| 212 |
# Forward through LLM with embeddings (bypass tokenization)
|
| 213 |
+
outputs = self.llm(
|
| 214 |
inputs_embeds=fused_embeds,
|
| 215 |
attention_mask=fused_mask,
|
| 216 |
labels=labels,
|
|
|
|
| 258 |
Returns:
|
| 259 |
Generated text string (decoded with MINDI tokenizer).
|
| 260 |
"""
|
| 261 |
+
self.llm.eval()
|
|
|
|
| 262 |
|
| 263 |
# Tokenize with MINDI tokenizer
|
| 264 |
inputs = self.tokenizer(prompt, return_tensors="pt")
|
|
|
|
| 267 |
|
| 268 |
# If image provided, build fused embeddings
|
| 269 |
if image is not None:
|
| 270 |
+
text_embeds = self.llm.get_input_embeddings()(input_ids)
|
| 271 |
visual_tokens = self.vision_encoder.encode_image(image)
|
| 272 |
fused_embeds, fused_mask = self.fusion(text_embeds, visual_tokens, attention_mask)
|
| 273 |
|
| 274 |
+
output_ids = self.llm.generate(
|
| 275 |
inputs_embeds=fused_embeds,
|
| 276 |
attention_mask=fused_mask,
|
| 277 |
max_new_tokens=max_new_tokens,
|
|
|
|
| 284 |
)
|
| 285 |
else:
|
| 286 |
# Text-only generation (direct input_ids)
|
| 287 |
+
output_ids = self.llm.generate(
|
| 288 |
input_ids=input_ids,
|
| 289 |
attention_mask=attention_mask,
|
| 290 |
max_new_tokens=max_new_tokens,
|
src/training/mindi_trainer.py
CHANGED
|
@@ -277,9 +277,8 @@ class MINDITrainer:
|
|
| 277 |
# Optional torch.compile (works on ROCm)
|
| 278 |
if config.use_compile:
|
| 279 |
print("[MINDITrainer] Compiling model with torch.compile() ...")
|
| 280 |
-
self.model.
|
| 281 |
-
|
| 282 |
-
)
|
| 283 |
print("[MINDITrainer] Compilation complete")
|
| 284 |
|
| 285 |
# Mixed precision scaler (bf16 doesn't need GradScaler, but keep structure)
|
|
|
|
| 277 |
# Optional torch.compile (works on ROCm)
|
| 278 |
if config.use_compile:
|
| 279 |
print("[MINDITrainer] Compiling model with torch.compile() ...")
|
| 280 |
+
self.model.llm = torch.compile(self.model.llm)
|
| 281 |
+
self.model.architecture.peft_model = self.model.llm
|
|
|
|
| 282 |
print("[MINDITrainer] Compilation complete")
|
| 283 |
|
| 284 |
# Mixed precision scaler (bf16 doesn't need GradScaler, but keep structure)
|