Faaz commited on
Commit
cdc806e
·
1 Parent(s): 02eef51

Fix: register LLM as nn.Module submodule so optimizer finds LoRA params

Browse files
src/model/mindi_model.py CHANGED
@@ -111,6 +111,9 @@ class MINDI15(nn.Module):
111
  # Apply LoRA
112
  self.architecture.apply_lora()
113
 
 
 
 
114
  # 3. Vision encoder (frozen CLIP + trainable projection)
115
  self.vision_encoder = VisionEncoder(
116
  model_name=clip_model,
@@ -183,10 +186,8 @@ class MINDI15(nn.Module):
183
  Returns:
184
  Dict with 'loss', 'logits', and optionally 'visual_tokens'.
185
  """
186
- model = self.architecture.get_model()
187
-
188
  # Get text embeddings from the LLM's embedding layer
189
- text_embeds = model.get_input_embeddings()(input_ids)
190
 
191
  # Encode vision if image provided
192
  visual_tokens = None
@@ -209,7 +210,7 @@ class MINDI15(nn.Module):
209
  labels = torch.cat([visual_labels, labels], dim=1)
210
 
211
  # Forward through LLM with embeddings (bypass tokenization)
212
- outputs = model(
213
  inputs_embeds=fused_embeds,
214
  attention_mask=fused_mask,
215
  labels=labels,
@@ -257,8 +258,7 @@ class MINDI15(nn.Module):
257
  Returns:
258
  Generated text string (decoded with MINDI tokenizer).
259
  """
260
- model = self.architecture.get_model()
261
- model.eval()
262
 
263
  # Tokenize with MINDI tokenizer
264
  inputs = self.tokenizer(prompt, return_tensors="pt")
@@ -267,11 +267,11 @@ class MINDI15(nn.Module):
267
 
268
  # If image provided, build fused embeddings
269
  if image is not None:
270
- text_embeds = model.get_input_embeddings()(input_ids)
271
  visual_tokens = self.vision_encoder.encode_image(image)
272
  fused_embeds, fused_mask = self.fusion(text_embeds, visual_tokens, attention_mask)
273
 
274
- output_ids = model.generate(
275
  inputs_embeds=fused_embeds,
276
  attention_mask=fused_mask,
277
  max_new_tokens=max_new_tokens,
@@ -284,7 +284,7 @@ class MINDI15(nn.Module):
284
  )
285
  else:
286
  # Text-only generation (direct input_ids)
287
- output_ids = model.generate(
288
  input_ids=input_ids,
289
  attention_mask=attention_mask,
290
  max_new_tokens=max_new_tokens,
 
111
  # Apply LoRA
112
  self.architecture.apply_lora()
113
 
114
+ # Register the LLM as a submodule so .parameters() finds it
115
+ self.llm = self.architecture.get_model()
116
+
117
  # 3. Vision encoder (frozen CLIP + trainable projection)
118
  self.vision_encoder = VisionEncoder(
119
  model_name=clip_model,
 
186
  Returns:
187
  Dict with 'loss', 'logits', and optionally 'visual_tokens'.
188
  """
 
 
189
  # Get text embeddings from the LLM's embedding layer
190
+ text_embeds = self.llm.get_input_embeddings()(input_ids)
191
 
192
  # Encode vision if image provided
193
  visual_tokens = None
 
210
  labels = torch.cat([visual_labels, labels], dim=1)
211
 
212
  # Forward through LLM with embeddings (bypass tokenization)
213
+ outputs = self.llm(
214
  inputs_embeds=fused_embeds,
215
  attention_mask=fused_mask,
216
  labels=labels,
 
258
  Returns:
259
  Generated text string (decoded with MINDI tokenizer).
260
  """
261
+ self.llm.eval()
 
262
 
263
  # Tokenize with MINDI tokenizer
264
  inputs = self.tokenizer(prompt, return_tensors="pt")
 
267
 
268
  # If image provided, build fused embeddings
269
  if image is not None:
270
+ text_embeds = self.llm.get_input_embeddings()(input_ids)
271
  visual_tokens = self.vision_encoder.encode_image(image)
272
  fused_embeds, fused_mask = self.fusion(text_embeds, visual_tokens, attention_mask)
273
 
274
+ output_ids = self.llm.generate(
275
  inputs_embeds=fused_embeds,
276
  attention_mask=fused_mask,
277
  max_new_tokens=max_new_tokens,
 
284
  )
285
  else:
286
  # Text-only generation (direct input_ids)
287
+ output_ids = self.llm.generate(
288
  input_ids=input_ids,
289
  attention_mask=attention_mask,
290
  max_new_tokens=max_new_tokens,
src/training/mindi_trainer.py CHANGED
@@ -277,9 +277,8 @@ class MINDITrainer:
277
  # Optional torch.compile (works on ROCm)
278
  if config.use_compile:
279
  print("[MINDITrainer] Compiling model with torch.compile() ...")
280
- self.model.architecture.peft_model = torch.compile(
281
- self.model.architecture.peft_model
282
- )
283
  print("[MINDITrainer] Compilation complete")
284
 
285
  # Mixed precision scaler (bf16 doesn't need GradScaler, but keep structure)
 
277
  # Optional torch.compile (works on ROCm)
278
  if config.use_compile:
279
  print("[MINDITrainer] Compiling model with torch.compile() ...")
280
+ self.model.llm = torch.compile(self.model.llm)
281
+ self.model.architecture.peft_model = self.model.llm
 
282
  print("[MINDITrainer] Compilation complete")
283
 
284
  # Mixed precision scaler (bf16 doesn't need GradScaler, but keep structure)