ywlee88 committed on
Commit
a4dcd96
·
verified ·
1 Parent(s): 5cdef25

Upload folder using huggingface_hub

Browse files
modeling_safellava.py CHANGED
@@ -8,7 +8,7 @@ SafeLLaVA adds image safety classification capabilities to LLaVA.
8
  """
9
 
10
  # Re-export classes from safellava package for HuggingFace auto_map
11
- from safellava.model.language_model.safe_llava_llama_pool import (
12
  SafetyConfig,
13
  SafeLlavaLlamaForCausalLM,
14
  SafetyCausalLMOutputWithPast,
 
8
  """
9
 
10
  # Re-export classes from safellava package for HuggingFace auto_map
11
+ from safellava.model.language_model.safe_llava_llama import (
12
  SafetyConfig,
13
  SafeLlavaLlamaForCausalLM,
14
  SafetyCausalLMOutputWithPast,
safellava/model/language_model/safe_llava_llama.py CHANGED
@@ -1,10 +1,3 @@
1
- """
2
- Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA
3
- Modified for SafeLLaVA
4
-
5
- Original LLaVA License: Apache License 2.0
6
- """
7
-
8
  from typing import List, Optional, Tuple, Union, Dict
9
 
10
  import torch
@@ -12,15 +5,15 @@ import torch.nn as nn
12
  from transformers import AutoConfig, AutoModelForCausalLM
13
  from transformers.modeling_outputs import CausalLMOutputWithPast
14
 
15
- from safellava.model.language_model.llava_llama import (
16
  LlavaConfig, LlavaLlamaModel, LlavaLlamaForCausalLM
17
  )
18
- from safellava.constants import IMAGE_TOKEN_INDEX
19
 
20
  from dataclasses import dataclass
21
 
22
  import logging
23
- from safellava.utils import setup_simple_logging
24
 
25
  setup_simple_logging()
26
 
@@ -65,7 +58,7 @@ class SafetyMLP(nn.Module):
65
 
66
 
67
  class SafetyConfig(LlavaConfig):
68
- """Safety-aware configuration for pooling version """
69
  model_type = "safe_llava_llama"
70
 
71
  def __init__(
@@ -115,18 +108,19 @@ class SafetyConfig(LlavaConfig):
115
  self.safety_head_hidden_scale = safety_head_hidden_scale
116
  self.pooling_method = pooling_method
117
 
118
- # self.use_img_safety_meta_token = False
 
119
  self.use_txt_safety_meta_token = False
120
  self.use_total_safety_meta_token = False
121
 
122
 
123
  class SafeLlavaLlamaForCausalLM(LlavaLlamaForCausalLM):
124
  """
125
- SafeLLaVA: A simplified version Uses pooled visual features for safety classification.
 
126
  """
127
 
128
  config_class = SafetyConfig
129
- _keys_to_ignore_on_load_unexpected = [] # Don't ignore img_safety_head weights
130
 
131
  def __init__(self, config: SafetyConfig):
132
  super().__init__(config)
@@ -138,7 +132,7 @@ class SafeLlavaLlamaForCausalLM(LlavaLlamaForCausalLM):
138
  output_size=len(config.safety_categories),
139
  safety_num_hidden_layers=config.safety_num_hidden_layers
140
  )
141
- logging.info("Created img_safety_head for SafeLLaVA")
142
 
143
  # Store pooling method
144
  self.pooling_method = config.pooling_method
@@ -153,79 +147,6 @@ class SafeLlavaLlamaForCausalLM(LlavaLlamaForCausalLM):
153
  def get_model(self):
154
  return self.model
155
 
156
- @classmethod
157
- def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
158
- """
159
- Custom from_pretrained to properly load img_safety_head weights.
160
- """
161
- import os
162
- import torch
163
- from pathlib import Path
164
-
165
- # Load the model normally first
166
- model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
167
-
168
- # List of original LLaVA model names
169
- original_llava_models = [
170
- "liuhaotian/llava-v1.5-7b",
171
- "liuhaotian/llava-v1.5-13b",
172
- ]
173
-
174
- is_original_llava = any(str(pretrained_model_name_or_path).startswith(name) for name in original_llava_models)
175
-
176
- # Load safety head weights for SafeLLaVA models
177
- if not is_original_llava:
178
- logging.info(f"Detected SafeLLaVA model: {pretrained_model_name_or_path}")
179
- model_path = Path(pretrained_model_name_or_path)
180
-
181
- # Handle both local paths and HuggingFace Hub
182
- if not model_path.exists():
183
- # Try HuggingFace cache
184
- from huggingface_hub import snapshot_download
185
- try:
186
- model_path = Path(snapshot_download(repo_id=str(pretrained_model_name_or_path)))
187
- logging.info(f"Downloaded from HuggingFace Hub to: {model_path}")
188
- except Exception as e:
189
- logging.warning(f"Could not download from Hub: {e}")
190
- return model
191
-
192
- if model_path.exists():
193
- # Load safety head weights from safetensors
194
- safetensors_index_path = model_path / "model.safetensors.index.json"
195
- if safetensors_index_path.exists():
196
- logging.info("Loading safety head weights from safetensors...")
197
- from safetensors.torch import load_file
198
- import json
199
-
200
- # Load the index file
201
- with open(safetensors_index_path, 'r') as f:
202
- index_data = json.load(f)
203
-
204
- # Load all safetensors files and collect safety head weights
205
- safety_weights = {}
206
- for weight_map in set(index_data.get('weight_map', {}).values()):
207
- safetensors_file = model_path / weight_map
208
- if safetensors_file.exists():
209
- file_weights = load_file(str(safetensors_file))
210
- # Extract only img_safety_head weights
211
- for key, value in file_weights.items():
212
- if key.startswith('img_safety_head.'):
213
- safety_weights[key] = value
214
-
215
- if safety_weights:
216
- logging.info(f"Found {len(safety_weights)} img_safety_head weights")
217
- # Load the weights
218
- missing_keys, unexpected_keys = model.load_state_dict(safety_weights, strict=False)
219
- logging.info("✅ Safety head weights loaded successfully")
220
- else:
221
- logging.warning("⚠️ No img_safety_head weights found in checkpoint")
222
- else:
223
- logging.warning(f"No safetensors index found at {safetensors_index_path}")
224
- else:
225
- logging.warning(f"Model path does not exist: {model_path}")
226
-
227
- return model
228
-
229
  def get_safety_warning(self, unsafe_categories):
230
  if len(unsafe_categories) == 1:
231
  category_str = f"related to {unsafe_categories[0]}"
@@ -314,6 +235,277 @@ class SafeLlavaLlamaForCausalLM(LlavaLlamaForCausalLM):
314
  pooled_features = torch.stack(pooled_features, dim=0)
315
  return pooled_features
316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  def forward(
318
  self,
319
  input_ids=None,
@@ -332,7 +524,7 @@ class SafeLlavaLlamaForCausalLM(LlavaLlamaForCausalLM):
332
  **kwargs,
333
  ) -> Union[Tuple, CausalLMOutputWithPast, SafetyCausalLMOutputWithPast]:
334
  """
335
- Forward method for SafeLLaVA.
336
  When do_safety=True, extracts and pools visual tokens for safety classification.
337
  """
338
 
 
 
 
 
 
 
 
 
1
  from typing import List, Optional, Tuple, Union, Dict
2
 
3
  import torch
 
5
  from transformers import AutoConfig, AutoModelForCausalLM
6
  from transformers.modeling_outputs import CausalLMOutputWithPast
7
 
8
+ from llava.model.language_model.llava_llama import (
9
  LlavaConfig, LlavaLlamaModel, LlavaLlamaForCausalLM
10
  )
11
+ from llava.constants import IMAGE_TOKEN_INDEX
12
 
13
  from dataclasses import dataclass
14
 
15
  import logging
16
+ from llava.utils import setup_simple_logging
17
 
18
  setup_simple_logging()
19
 
 
58
 
59
 
60
  class SafetyConfig(LlavaConfig):
61
+ """Safety-aware configuration for pooling version without meta tokens"""
62
  model_type = "safe_llava_llama"
63
 
64
  def __init__(
 
108
  self.safety_head_hidden_scale = safety_head_hidden_scale
109
  self.pooling_method = pooling_method
110
 
111
+ # Pool version doesn't use meta tokens
112
+ self.use_img_safety_meta_token = False
113
  self.use_txt_safety_meta_token = False
114
  self.use_total_safety_meta_token = False
115
 
116
 
117
  class SafeLlavaLlamaForCausalLM(LlavaLlamaForCausalLM):
118
  """
119
+ SafeLLaVA-Pool: A simplified version without meta tokens.
120
+ Pools visual tokens directly for safety classification.
121
  """
122
 
123
  config_class = SafetyConfig
 
124
 
125
  def __init__(self, config: SafetyConfig):
126
  super().__init__(config)
 
132
  output_size=len(config.safety_categories),
133
  safety_num_hidden_layers=config.safety_num_hidden_layers
134
  )
135
+ logging.info("Created img_safety_head for SafeLLaVA-Pool")
136
 
137
  # Store pooling method
138
  self.pooling_method = config.pooling_method
 
147
  def get_model(self):
148
  return self.model
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  def get_safety_warning(self, unsafe_categories):
151
  if len(unsafe_categories) == 1:
152
  category_str = f"related to {unsafe_categories[0]}"
 
235
  pooled_features = torch.stack(pooled_features, dim=0)
236
  return pooled_features
237
 
238
def compute_gradcam(
    self,
    input_ids=None,
    attention_mask=None,
    images=None,
    image_sizes=None,
    target_class=None,
    use_pre_pooling=False,
    **kwargs,
):
    """
    Compute a Grad-CAM style heatmap for the image safety classification.

    The heatmap is ``relu(sum_over_hidden(gradient * activation))`` per visual
    token, reshaped to a square grid and min-max normalised per sample. The
    gradient is that of the summed target-class logit with respect to the
    captured activation.

    The model is left exactly as it was found: no parameter has its
    ``requires_grad`` flag changed, no ``.grad`` buffer is written or cleared,
    train/eval modes are untouched, and the temporary method overrides are
    always removed, even when the forward pass raises.

    Args:
        input_ids: Input token IDs, passed through to ``self.forward``.
        attention_mask: Attention mask, passed through to ``self.forward``.
        images: Input images; required.
        image_sizes: Image sizes, passed through to ``self.forward``.
        target_class: Class index to explain (int, sequence or tensor). If
            None, the class with the highest probability is used per sample.
        use_pre_pooling: If True, use the hidden states handed to
            ``pool_visual_tokens`` instead of the vision tower output.
        **kwargs: Extra keyword arguments forwarded to ``self.forward``.

    Returns:
        dict with keys:
            - 'heatmap': float32 array [batch_size, side, side] in [0, 1]
            - 'predicted_class': array of target class indices
            - 'predicted_prob': probability of the target class per sample
            - 'class_name': category name of the FIRST sample's target class
            - 'all_probs': array with all class probabilities

    Raises:
        ValueError: If ``images`` is None.
        AttributeError: If the model has no vision tower.
        IndexError: If ``target_class`` is outside the valid class range.
        RuntimeError: If activations or gradients could not be captured, or
            the visual tokens cannot be arranged into a square grid.
    """
    import math

    if images is None:
        raise ValueError("Images are required for Grad-CAM computation")

    # Validate before the first attribute access on the tower.
    vision_tower = self.get_vision_tower()
    if vision_tower is None:
        raise AttributeError("Vision tower not found")

    captured = {}   # "vision" / "hidden" tensors and "positions" index list
    restorers = []  # undo callbacks for every temporary attribute override

    def shadow_attr(owner, attr_name, replacement):
        """Override ``owner.attr_name`` on the instance and queue its undo."""
        had_own = attr_name in vars(owner)
        previous = getattr(owner, attr_name)
        setattr(owner, attr_name, replacement)

        def restore():
            if had_own:
                # Something (e.g. a dispatch hook) had already installed an
                # instance-level attribute: put exactly that value back.
                setattr(owner, attr_name, previous)
            else:
                # Drop the shadow so the class-level method is visible again.
                delattr(owner, attr_name)

        restorers.append(restore)

    original_forward = vision_tower.forward

    def forward_with_capture(*f_args, **f_kwargs):
        """Run the real vision forward and keep a differentiable output."""
        features = original_forward(*f_args, **f_kwargs)
        if not isinstance(features, torch.Tensor):
            # Unsupported output type; reported below as a capture failure.
            return features
        if not features.requires_grad:
            # The tower may run without autograd; a fresh leaf lets the
            # gradient of the safety logit be taken w.r.t. its output.
            features = features.clone().requires_grad_(True)
        captured.setdefault("vision", features)
        return features

    try:
        shadow_attr(vision_tower, "forward", forward_with_capture)

        if use_pre_pooling:
            original_pool = self.pool_visual_tokens

            def pool_with_capture(hidden_states, input_ids, images):
                """Remember the tensor that is really pooled, then pool it."""
                if "hidden" not in captured:
                    captured["hidden"] = hidden_states
                    # NOTE(review): positions are taken where input_ids equals
                    # IMAGE_TOKEN_INDEX, as in the previous implementation.
                    # Confirm this matches the token layout of hidden_states.
                    captured["positions"] = [
                        (row == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
                        for row in input_ids
                    ]
                return original_pool(hidden_states, input_ids, images)

            shadow_attr(self, "pool_visual_tokens", pool_with_capture)

        with torch.enable_grad():
            outputs = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                images=images,
                image_sizes=image_sizes,
                do_safety=True,
                return_dict=True,
                **kwargs
            )
            img_safety_logits = outputs.img_safety_logits
            img_safety_probs = outputs.img_safety_probs
            batch_size = img_safety_probs.shape[0]

            if target_class is None:
                # Explain the most probable class of every sample.
                target_class = img_safety_probs.argmax(dim=-1)
            else:
                target_class = torch.as_tensor(
                    target_class, dtype=torch.long, device=img_safety_logits.device
                ).reshape(-1)
                num_classes = img_safety_logits.shape[-1]
                out_of_range = (target_class < -num_classes) | (target_class >= num_classes)
                if bool(out_of_range.any()):
                    raise IndexError(
                        f"target_class must be in [-{num_classes}, {num_classes - 1}]"
                    )

            batch_index = torch.arange(batch_size, device=img_safety_logits.device)
            target_score = img_safety_logits[batch_index, target_class].sum()

            source = captured.get("hidden" if use_pre_pooling else "vision")
            if source is None:
                raise RuntimeError(
                    "Failed to capture pre-pooling activations"
                    if use_pre_pooling else "Failed to capture activations"
                )
            if not source.requires_grad:
                raise RuntimeError(
                    "Failed to capture pre-pooling gradients"
                    if use_pre_pooling else "Failed to capture gradients"
                )

            # Differentiate only w.r.t. the captured tensor: parameter .grad
            # buffers are neither populated nor cleared.
            (gradient,) = torch.autograd.grad(target_score, source, allow_unused=True)
    finally:
        for restore in reversed(restorers):
            restore()

    if gradient is None:
        raise RuntimeError(
            "Failed to capture pre-pooling gradients"
            if use_pre_pooling else "Failed to capture gradients"
        )

    activation = source.detach()

    if use_pre_pooling:
        token_positions = captured["positions"]
        if not any(len(pos) for pos in token_positions):
            raise RuntimeError("Failed to capture pre-pooling activations")
        if len({len(pos) for pos in token_positions}) != 1:
            raise RuntimeError(
                "Samples have differing numbers of image tokens; "
                "cannot build a batched heatmap"
            )
        token_index = torch.stack(token_positions)  # [batch, tokens]
        row_index = torch.arange(
            token_index.shape[0], device=token_index.device
        ).unsqueeze(1)
        # Each sample is indexed with its own positions.
        activation = activation[row_index, token_index]
        gradient = gradient[row_index, token_index]

    # Accumulate in float32 so half-precision products cannot overflow.
    cam = (gradient.float() * activation.float()).sum(dim=-1)
    if cam.dim() != 2:
        raise RuntimeError(
            f"Expected activations of shape [batch, tokens, hidden], got {tuple(activation.shape)}"
        )
    # Keep only positive contributions.
    cam = torch.nn.functional.relu(cam)

    num_tokens = cam.shape[1]
    side = math.isqrt(num_tokens)
    if side * side != num_tokens:
        raise RuntimeError(
            f"Cannot arrange {num_tokens} visual tokens into a square heatmap"
        )
    cam = cam.reshape(cam.shape[0], side, side)

    # Min-max normalise each sample; constant maps are left unchanged.
    for sample_idx in range(cam.shape[0]):
        cam_min = cam[sample_idx].min()
        cam_max = cam[sample_idx].max()
        if cam_max > cam_min:
            cam[sample_idx] = (cam[sample_idx] - cam_min) / (cam_max - cam_min)

    target_class_idx = int(target_class.reshape(-1)[0].item())
    class_name = self.config.safety_categories[target_class_idx]

    return {
        'heatmap': cam.detach().cpu().numpy(),
        'predicted_class': target_class.detach().cpu().numpy(),
        # .float() keeps the conversion valid for bfloat16 models.
        'predicted_prob': img_safety_probs[batch_index, target_class].detach().float().cpu().numpy(),
        'class_name': class_name,
        'all_probs': img_safety_probs.detach().float().cpu().numpy()
    }
508
+
509
  def forward(
510
  self,
511
  input_ids=None,
 
524
  **kwargs,
525
  ) -> Union[Tuple, CausalLMOutputWithPast, SafetyCausalLMOutputWithPast]:
526
  """
527
+ Forward method for SafeLLaVA-Pool.
528
  When do_safety=True, extracts and pools visual tokens for safety classification.
529
  """
530