fariasultana committed
Commit 826f659 · verified · 1 Parent(s): 0108694

feat: Add capabilities/vision.py

Files changed (1)
  1. capabilities/vision.py +529 -0
capabilities/vision.py ADDED
"""
Multimodal Vision Module for MiniMind Max2
Adapter-based approach using SigLIP/DINOv2 vision encoders.
"""

from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math


@dataclass
class VisionConfig:
    """Configuration for the vision adapter."""
    # Vision encoder settings
    vision_encoder: str = "siglip-so400m"  # siglip-so400m, dinov2-small, clip-vit-base
    vision_hidden_size: int = 1152         # SigLIP-So400M hidden size
    image_size: int = 384
    patch_size: int = 14
    num_image_tokens: int = 729            # (384 // 14) ** 2 = 27 * 27 = 729 patches

    # Projector settings
    projector_type: str = "mlp"  # mlp, linear, resampler
    projector_hidden_size: int = 2048
    projector_num_layers: int = 2

    # LLM settings (to match MiniMind)
    llm_hidden_size: int = 1024  # MiniMind hidden size

    # Training settings
    freeze_vision_encoder: bool = True
    freeze_llm: bool = True
    train_projector_only: bool = True

    # Special tokens
    image_start_token: str = "<image>"
    image_end_token: str = "</image>"
    image_pad_token: str = "<image_pad>"
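
# Illustrative configuration sketch (the values here are examples, not defaults
# used elsewhere in this module): the token count should stay consistent with
# the patch grid, i.e. num_image_tokens == (image_size // patch_size) ** 2.
#
#   cfg = VisionConfig(projector_type="resampler", projector_hidden_size=1024)
#   assert cfg.num_image_tokens == (cfg.image_size // cfg.patch_size) ** 2  # 27 * 27 = 729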


class MLPProjector(nn.Module):
    """
    Multi-Layer Perceptron projector for vision-language alignment.
    Maps vision encoder outputs to LLM embedding space.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config

        layers = []
        input_size = config.vision_hidden_size

        for i in range(config.projector_num_layers):
            if i == config.projector_num_layers - 1:
                # Last layer projects to LLM size
                layers.append(nn.Linear(input_size, config.llm_hidden_size))
            else:
                # Hidden layers
                layers.extend([
                    nn.Linear(input_size, config.projector_hidden_size),
                    nn.GELU(),
                    nn.LayerNorm(config.projector_hidden_size),
                ])
                input_size = config.projector_hidden_size

        self.projector = nn.Sequential(*layers)

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """
        Project vision features to LLM space.

        Args:
            vision_features: [batch, num_patches, vision_hidden_size]

        Returns:
            Projected features: [batch, num_patches, llm_hidden_size]
        """
        return self.projector(vision_features)
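
# Illustrative shape check for the projector (a sketch; the tensor is random and
# stands in for real SigLIP features):
#
#   projector = MLPProjector(VisionConfig())
#   feats = torch.randn(2, 729, 1152)   # [batch, num_patches, vision_hidden_size]
#   out = projector(feats)              # -> [2, 729, 1024], the LLM embedding width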


class Resampler(nn.Module):
    """
    Perceiver-style resampler for compressing vision tokens.
    Reduces the number of image tokens while preserving information.
    """

    def __init__(
        self,
        config: VisionConfig,
        num_queries: int = 64,
        num_heads: int = 8,
        num_layers: int = 2,
    ):
        super().__init__()
        self.config = config
        self.num_queries = num_queries

        # Learnable query tokens
        self.queries = nn.Parameter(torch.randn(1, num_queries, config.llm_hidden_size))

        # Input projection
        self.input_proj = nn.Linear(config.vision_hidden_size, config.llm_hidden_size)

        # Cross-attention layers
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=config.llm_hidden_size,
                nhead=num_heads,
                dim_feedforward=config.llm_hidden_size * 4,
                batch_first=True,
            )
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(config.llm_hidden_size)

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """
        Resample vision features using learned queries.

        Args:
            vision_features: [batch, num_patches, vision_hidden_size]

        Returns:
            Resampled features: [batch, num_queries, llm_hidden_size]
        """
        batch_size = vision_features.shape[0]

        # Project vision features
        vision_features = self.input_proj(vision_features)

        # Expand queries for batch
        queries = self.queries.expand(batch_size, -1, -1)

        # Cross-attend to vision features
        for layer in self.layers:
            queries = layer(queries, vision_features)

        return self.norm(queries)
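
# Illustrative compression example (a sketch; random features stand in for real
# encoder outputs): 729 patch tokens are squeezed down to 64 query tokens
# before they reach the LLM.
#
#   cfg = VisionConfig(projector_type="resampler")
#   resampler = Resampler(cfg, num_queries=64)
#   feats = torch.randn(2, 729, cfg.vision_hidden_size)   # [2, 729, 1152]
#   out = resampler(feats)                                 # [2, 64, 1024]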


class VisionEncoder(nn.Module):
    """
    Wrapper for pre-trained vision encoders.
    Supports SigLIP, DINOv2, and CLIP.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config
        self.encoder = None
        self.processor = None

        # Placeholder for actual encoder loading
        # In practice, load from HuggingFace
        self._build_dummy_encoder()

    def _build_dummy_encoder(self):
        """Build a dummy encoder for testing."""
        # Simple ViT-like encoder
        patch_dim = 3 * (self.config.patch_size ** 2)
        num_patches = (self.config.image_size // self.config.patch_size) ** 2

        self.patch_embed = nn.Linear(patch_dim, self.config.vision_hidden_size)
        self.pos_embed = nn.Parameter(
            torch.randn(1, num_patches + 1, self.config.vision_hidden_size) * 0.02
        )
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, self.config.vision_hidden_size) * 0.02
        )

        # Transformer layers
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=self.config.vision_hidden_size,
                nhead=8,
                dim_feedforward=self.config.vision_hidden_size * 4,
                batch_first=True,
            )
            for _ in range(6)
        ])
        self.norm = nn.LayerNorm(self.config.vision_hidden_size)

    def patchify(self, images: torch.Tensor) -> torch.Tensor:
        """Convert images to patches."""
        batch_size, c, h, w = images.shape
        p = self.config.patch_size

        # [B, C, H, W] -> [B, num_patches, patch_dim]
        patches = images.unfold(2, p, p).unfold(3, p, p)
        patches = patches.contiguous().view(batch_size, c, -1, p, p)
        patches = patches.permute(0, 2, 1, 3, 4).contiguous()
        patches = patches.view(batch_size, -1, c * p * p)

        return patches

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode images to feature vectors.

        Args:
            images: [batch, 3, height, width] normalized images

        Returns:
            Vision features: [batch, num_patches, vision_hidden_size]
        """
        batch_size = images.shape[0]

        # Patchify and embed
        patches = self.patchify(images)
        x = self.patch_embed(patches)

        # Add CLS token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add positional embeddings
        x = x + self.pos_embed[:, :x.shape[1], :]

        # Transformer
        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)

        # Return patch features (exclude CLS)
        return x[:, 1:, :]

    @classmethod
    def from_pretrained(cls, model_name: str, config: VisionConfig) -> "VisionEncoder":
        """Load pre-trained vision encoder."""
        encoder = cls(config)

        # In practice, load weights from HuggingFace
        # try:
        #     from transformers import SiglipVisionModel, AutoProcessor
        #     encoder.encoder = SiglipVisionModel.from_pretrained(model_name)
        #     encoder.processor = AutoProcessor.from_pretrained(model_name)
        # except ImportError:
        #     pass

        return encoder
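
# Worked shape example for the dummy encoder (a sketch using random pixels):
# with image_size=384 and patch_size=14, unfold keeps floor(384 / 14) = 27
# patches per side, so 27 * 27 = 729 patches of dimension 3 * 14 * 14 = 588.
#
#   enc = VisionEncoder(VisionConfig())
#   imgs = torch.randn(2, 3, 384, 384)
#   print(enc.patchify(imgs).shape)   # torch.Size([2, 729, 588])
#   print(enc(imgs).shape)            # torch.Size([2, 729, 1152])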


class VisionAdapter(nn.Module):
    """
    Complete vision adapter for MiniMind Max2.
    Connects the vision encoder to the LLM via a projector.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config

        # Vision encoder
        self.vision_encoder = VisionEncoder(config)

        # Projector
        if config.projector_type == "mlp":
            self.projector = MLPProjector(config)
        elif config.projector_type == "resampler":
            self.projector = Resampler(config)
        else:
            self.projector = nn.Linear(config.vision_hidden_size, config.llm_hidden_size)

        # Freeze components as needed
        if config.freeze_vision_encoder:
            for param in self.vision_encoder.parameters():
                param.requires_grad = False

    def forward(
        self,
        images: torch.Tensor,
        return_features: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Process images and project them to the LLM space.

        Args:
            images: [batch, 3, height, width]
            return_features: Also return the raw vision features

        Returns:
            Projected features: [batch, num_tokens, llm_hidden_size].
            If return_features is True, a (projected, vision_features) tuple.
        """
        # Encode images
        vision_features = self.vision_encoder(images)

        # Project to LLM space
        projected = self.projector(vision_features)

        if return_features:
            return projected, vision_features
        return projected

    def get_num_image_tokens(self) -> int:
        """Get the number of tokens per image."""
        if isinstance(self.projector, Resampler):
            return self.projector.num_queries
        return self.config.num_image_tokens
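
# End-to-end adapter sketch (illustrative; it runs on the dummy encoder defined
# above, so the features are random rather than real SigLIP activations):
#
#   adapter = VisionAdapter(VisionConfig(projector_type="mlp"))
#   tokens = adapter(torch.randn(2, 3, 384, 384))       # [2, 729, 1024]
#   assert tokens.shape[1] == adapter.get_num_image_tokens()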


class MiniMindVision(nn.Module):
    """
    Complete vision-language model combining MiniMind Max2 with vision adapter.
    """

    def __init__(
        self,
        llm_model: nn.Module,
        vision_config: Optional[VisionConfig] = None,
    ):
        super().__init__()

        # Get LLM config
        if hasattr(llm_model, 'config'):
            llm_hidden_size = llm_model.config.hidden_size
        else:
            llm_hidden_size = 1024

        # Vision config
        self.vision_config = vision_config or VisionConfig(llm_hidden_size=llm_hidden_size)

        # Components
        self.llm = llm_model
        self.vision_adapter = VisionAdapter(self.vision_config)

        # Freeze LLM if needed
        if self.vision_config.freeze_llm:
            for param in self.llm.parameters():
                param.requires_grad = False

    def merge_vision_text_embeddings(
        self,
        text_embeddings: torch.Tensor,
        vision_embeddings: torch.Tensor,
        image_positions: torch.Tensor,
    ) -> torch.Tensor:
        """
        Merge vision embeddings into text embedding sequence.

        Args:
            text_embeddings: [batch, text_seq_len, hidden_size]
            vision_embeddings: [batch, num_image_tokens, hidden_size]
            image_positions: [batch] position indices for image tokens

        Returns:
            Merged embeddings: [batch, total_seq_len, hidden_size]
        """
        batch_size = text_embeddings.shape[0]
        num_image_tokens = vision_embeddings.shape[1]

        # Calculate output sequence length
        text_len = text_embeddings.shape[1]
        total_len = text_len + num_image_tokens

        # Create output tensor
        merged = torch.zeros(
            batch_size, total_len, text_embeddings.shape[-1],
            device=text_embeddings.device,
            dtype=text_embeddings.dtype,
        )

        for i in range(batch_size):
            pos = image_positions[i].item()

            # Text before image
            if pos > 0:
                merged[i, :pos] = text_embeddings[i, :pos]

            # Image tokens
            merged[i, pos:pos + num_image_tokens] = vision_embeddings[i]

            # Text after image
            if pos < text_len:
                merged[i, pos + num_image_tokens:] = text_embeddings[i, pos:]

        return merged
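
    # Worked example of the merge layout (illustrative numbers): with
    # text_len = 10, pos = 3 and 729 image tokens, the merged sequence has
    # 739 positions arranged as
    #   [text 0:3][729 image tokens][text 3:10]
    # i.e. the image block is spliced in at `pos` and the remaining text is
    # shifted right by num_image_tokens.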
386
+ def forward(
387
+ self,
388
+ input_ids: torch.LongTensor,
389
+ images: Optional[torch.Tensor] = None,
390
+ image_positions: Optional[torch.Tensor] = None,
391
+ attention_mask: Optional[torch.Tensor] = None,
392
+ labels: Optional[torch.LongTensor] = None,
393
+ ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
394
+ """
395
+ Forward pass with optional images.
396
+
397
+ Args:
398
+ input_ids: Text token IDs
399
+ images: Optional batch of images
400
+ image_positions: Where to insert image tokens
401
+ attention_mask: Attention mask for text
402
+ labels: Labels for language modeling
403
+
404
+ Returns:
405
+ Loss (if labels provided) and logits
406
+ """
407
+ # Get text embeddings from LLM
408
+ if hasattr(self.llm, 'model'):
409
+ text_embeddings = self.llm.model.embed_tokens(input_ids)
410
+ else:
411
+ text_embeddings = self.llm.embed_tokens(input_ids)
412
+
413
+ # Process images if provided
414
+ if images is not None:
415
+ vision_embeddings = self.vision_adapter(images)
416
+
417
+ if image_positions is None:
418
+ # Default: insert at beginning
419
+ image_positions = torch.zeros(images.shape[0], dtype=torch.long, device=images.device)
420
+
421
+ # Merge embeddings
422
+ merged_embeddings = self.merge_vision_text_embeddings(
423
+ text_embeddings, vision_embeddings, image_positions
424
+ )
425
+
426
+ # Update attention mask
427
+ if attention_mask is not None:
428
+ num_image_tokens = vision_embeddings.shape[1]
429
+ image_mask = torch.ones(
430
+ images.shape[0], num_image_tokens,
431
+ device=attention_mask.device,
432
+ dtype=attention_mask.dtype,
433
+ )
434
+ attention_mask = torch.cat([image_mask, attention_mask], dim=1)
435
+ else:
436
+ merged_embeddings = text_embeddings
437
+
438
+ # Forward through LLM (need to modify to accept embeddings directly)
439
+ # This is a simplified version
440
+ loss, logits, _, _ = self.llm(
441
+ input_ids=input_ids,
442
+ attention_mask=attention_mask,
443
+ labels=labels,
444
+ )
445
+
446
+ return loss, logits
447
+
448
+ @torch.no_grad()
449
+ def caption_image(
450
+ self,
451
+ image: torch.Tensor,
452
+ prompt: str = "Describe this image:",
453
+ max_new_tokens: int = 100,
454
+ tokenizer = None,
455
+ ) -> str:
456
+ """Generate caption for an image."""
457
+ self.eval()
458
+
459
+ # Encode image
460
+ vision_embeddings = self.vision_adapter(image.unsqueeze(0))
461
+
462
+ # Tokenize prompt
463
+ if tokenizer is not None:
464
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(image.device)
465
+ else:
466
+ # Dummy for testing
467
+ input_ids = torch.randint(0, 1000, (1, 10), device=image.device)
468
+
469
+ # Generate (simplified)
470
+ # In practice, would use the merged embeddings
471
+ generated = self.llm.generate(
472
+ input_ids,
473
+ max_new_tokens=max_new_tokens,
474
+ )
475
+
476
+ if tokenizer is not None:
477
+ return tokenizer.decode(generated[0], skip_special_tokens=True)
478
+ return "Generated caption placeholder"
479
+
480
+


class VisionDataset(Dataset):
    """Dataset for vision-language training."""

    def __init__(
        self,
        data_path: str,
        tokenizer,
        image_processor,
        max_length: int = 512,
    ):
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_length = max_length
        self.examples = []

        # Load data (e.g., LLaVA-150k format)
        import json
        with open(data_path, 'r') as f:
            self.examples = json.load(f)

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        example = self.examples[idx]

        # Load and process the image
        # In practice: image = Image.open(example["image"]).convert("RGB")
        #              image = self.image_processor(image)

        # Dummy image for now
        image = torch.randn(3, 384, 384)

        # Tokenize text
        text = example.get("conversations", [{"value": "Describe the image."}])[0]["value"]
        encodings = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = encodings["input_ids"].squeeze(0)

        return {
            "image": image,
            "input_ids": input_ids,
            "attention_mask": encodings["attention_mask"].squeeze(0),
            # Labels are a copy of the input IDs; padding positions would
            # typically be set to -100 so the loss ignores them.
            "labels": input_ids.clone(),
        }
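

# Minimal smoke test (an illustrative sketch, not a training entry point).
# It exercises only the dummy vision encoder and the projectors defined above,
# so it runs without a MiniMind checkpoint; the printed shapes assume the
# default VisionConfig (384px images, 14px patches, 1024-dim LLM embeddings).
if __name__ == "__main__":
    cfg = VisionConfig()

    # MLP projector path: images -> 729 patch tokens in the LLM embedding space
    adapter = VisionAdapter(cfg)
    images = torch.randn(2, 3, cfg.image_size, cfg.image_size)
    print("mlp tokens:", adapter(images).shape)            # [2, 729, 1024]
    print("tokens per image:", adapter.get_num_image_tokens())  # 729

    # Resampler path: the same images compressed to 64 learned query tokens
    resampler_adapter = VisionAdapter(VisionConfig(projector_type="resampler"))
    print("resampled tokens:", resampler_adapter(images).shape)  # [2, 64, 1024]
    print("tokens per image:", resampler_adapter.get_num_image_tokens())  # 64

    # A VisionDataset would then feed a DataLoader, e.g. (paths, tokenizer and
    # image processor are placeholders, not files shipped with this module):
    #   loader = DataLoader(
    #       VisionDataset("data.json", tokenizer, image_processor), batch_size=8
    #   )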