"""
BioPAN Omni Standalone Model for HuggingFace
=============================================
A model-agnostic multimodal LLM that supports multiple backends:
- LLM: Llama, Qwen2, Qwen3, Phi
- Vision: CLIP, SigLIP, DINOv2, MobileViT
- Connector: Identity, Linear, MLP, Resampler
Usage:
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("your-repo", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("your-repo")
# Multimodal
output = model.chat(
prompt="<image>\nDescribe this image.",
image="path/to/image.jpg",
tokenizer=tokenizer
)
# Text-only
output = model.chat(
prompt="What is the capital of France?",
tokenizer=tokenizer
)
"""
import re
import requests
from PIL import Image
from io import BytesIO
from dataclasses import dataclass
from typing import List, Tuple, Optional, Union
import torch
import torch.nn as nn
from transformers import (
PreTrainedModel,
PretrainedConfig,
AutoConfig,
AutoModelForCausalLM,
GenerationMixin,
StoppingCriteria,
# LLM backends
LlamaForCausalLM,
Qwen2ForCausalLM,
Qwen3ForCausalLM,
PhiForCausalLM,
# Vision backends
CLIPVisionModel,
CLIPImageProcessor,
SiglipVisionModel,
SiglipImageProcessor,
Dinov2Model,
MobileViTModel,
AutoImageProcessor,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
# =============================================================================
# CONSTANTS
# =============================================================================
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
# =============================================================================
# BACKEND MAPPINGS
# =============================================================================
LLM_MAPPING = {
"llama": LlamaForCausalLM,
"qwen2": Qwen2ForCausalLM,
"qwen3": Qwen3ForCausalLM, # Qwen3 uses Qwen2 architecture
"phi": PhiForCausalLM,
}
VISION_TOWER_MAPPING = {
"clip": (CLIPVisionModel, CLIPImageProcessor),
"siglip": (SiglipVisionModel, SiglipImageProcessor),
"dinov2": (Dinov2Model, AutoImageProcessor),
"mobilevit": (MobileViTModel, AutoImageProcessor),
}
# =============================================================================
# TEMPLATE CLASSES
# =============================================================================
@dataclass
class LlamaTemplate:
"""Template for Llama/Vicuna models"""
system: str = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
user_format: str = "USER: {content} "
assistant_format: str = "ASSISTANT: {content}</s>"
image_format: str = "<image>\n{content}"
stop_str: str = "</s>"
def format_chat(self, prompt: str, has_image: bool = False) -> str:
if has_image:
clean_prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, '').strip()
formatted = self.image_format.format(content=clean_prompt)
else:
formatted = prompt
return self.system + self.user_format.format(content=formatted) + "ASSISTANT:"
@dataclass
class Qwen2Template:
"""Template for Qwen2 base models"""
system: str = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
user_format: str = "USER: {content} "
assistant_format: str = "ASSISTANT: {content}<|endoftext|>"
image_format: str = "<image>\n{content}"
stop_str: str = "<|endoftext|>"
def format_chat(self, prompt: str, has_image: bool = False) -> str:
if has_image:
clean_prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, '').strip()
formatted = self.image_format.format(content=clean_prompt)
else:
formatted = prompt
return self.system + self.user_format.format(content=formatted) + "ASSISTANT:"
@dataclass
class Qwen3Template:
"""Template for Qwen3 base models"""
system: str = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
user_format: str = "USER: {content} "
assistant_format: str = "ASSISTANT: {content}<|im_end|>"
image_format: str = "<image>\n{content}"
stop_str: str = "<|im_end|>"
def format_chat(self, prompt: str, has_image: bool = False) -> str:
if has_image:
clean_prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, '').strip()
formatted = self.image_format.format(content=clean_prompt)
else:
formatted = prompt
return self.system + self.user_format.format(content=formatted) + "ASSISTANT:"
@dataclass
class PhiTemplate:
"""Template for Phi models"""
system: str = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
user_format: str = "USER: {content} "
assistant_format: str = "ASSISTANT: {content}<|endoftext|>"
image_format: str = "<image>\n{content}"
stop_str: str = "<|endoftext|>"
def format_chat(self, prompt: str, has_image: bool = False) -> str:
if has_image:
clean_prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, '').strip()
formatted = self.image_format.format(content=clean_prompt)
else:
formatted = prompt
return self.system + self.user_format.format(content=formatted) + "ASSISTANT:"
TEMPLATE_MAPPING = {
"llama": LlamaTemplate,
"qwen2": Qwen2Template,
"qwen3": Qwen3Template,
"phi": PhiTemplate,
}
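# Illustrative example (not executed): with the Llama template,
# format_chat("<image>\nDescribe this image.", has_image=True) produces
#   "A chat between a curious user and an artificial intelligence assistant. The assistant "
#   "gives helpful, detailed, and polite answers to the user's questions. "
#   "USER: <image>\nDescribe this image. ASSISTANT:"
# The four templates share the same system prompt and user format and differ only in
# their end-of-turn / stop tokens.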
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================
def load_image(image_file: str) -> Image.Image:
"""Load image from URL or file path"""
if image_file.startswith("http") or image_file.startswith("https"):
response = requests.get(image_file, timeout=30)
image = Image.open(BytesIO(response.content)).convert("RGB")
else:
image = Image.open(image_file).convert("RGB")
return image
def expand2square(pil_img: Image.Image, background_color: tuple) -> Image.Image:
"""Expand image to square with padding"""
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
def process_images(images: List[Image.Image], image_processor, model_cfg) -> torch.Tensor:
"""Process images for the vision tower"""
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
new_images = []
if image_aspect_ratio == 'pad':
for image in images:
image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
new_images.append(image)
else:
return image_processor(images, return_tensors='pt')['pixel_values']
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
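# Illustrative note: with image_aspect_ratio == 'pad', each image is first padded to a
# square using the processor's mean color (roughly (122, 116, 104) for a CLIP-style
# processor with image_mean ~ (0.481, 0.458, 0.408)) before preprocessing; otherwise
# the processor's default resize/crop pipeline is applied directly.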
def tokenizer_image_token(
prompt: str,
tokenizer,
image_token_index: int = IMAGE_TOKEN_INDEX,
    return_tensors: Optional[str] = None
):
"""Tokenize prompt with image token placeholders"""
def _insert_separator(X, sep):
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in _insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == 'pt':
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f'Unsupported tensor type: {return_tensors}')
return input_ids
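# Illustrative example (exact ids depend on the tokenizer): for the prompt
# "USER: <image>\nHi ASSISTANT:", the text on either side of "<image>" is tokenized
# separately and the placeholder becomes a single IMAGE_TOKEN_INDEX (-200) entry,
# e.g. [bos, ...USER: ids..., -200, ...Hi ASSISTANT: ids...]. The -200 entries are
# later replaced by image features in prepare_inputs_labels_for_multimodal.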
class KeywordsStoppingCriteria(StoppingCriteria):
"""Stop generation when specific keywords are generated"""
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.keyword_ids = []
self.max_keyword_len = 0
for keyword in keywords:
cur_keyword_ids = tokenizer(keyword).input_ids
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
cur_keyword_ids = cur_keyword_ids[1:]
if len(cur_keyword_ids) > self.max_keyword_len:
self.max_keyword_len = len(cur_keyword_ids)
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
self.tokenizer = tokenizer
self.start_len = input_ids.shape[1]
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        offset = min(input_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(input_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            if input_ids.shape[1] >= keyword_id.shape[0] and (input_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        # Fall back to string matching over the most recently generated tokens
        if offset > 0:
            outputs = self.tokenizer.batch_decode(input_ids[:, -offset:], skip_special_tokens=True)[0]
            for keyword in self.keywords:
                if keyword in outputs:
                    return True
        return False
# =============================================================================
# VISION TOWER
# =============================================================================
class VisionTower(nn.Module):
"""Vision encoder that supports multiple backends"""
def __init__(self, config):
super().__init__()
vision_name = config.vision_model_name_or_path.lower()
# Detect vision backend
self.backend = None
for key in VISION_TOWER_MAPPING:
if key in vision_name:
self.backend = key
break
if self.backend is None:
raise ValueError(f"Unsupported vision tower: {vision_name}")
model_class, processor_class = VISION_TOWER_MAPPING[self.backend]
self._vision_tower = model_class(config.vision_config)
self._image_processor = processor_class.from_pretrained(config.vision_config.model_name_or_path)
self.config = config
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
image_features = self._vision_tower(x, output_hidden_states=True)
image_features = image_features.hidden_states[kwargs.get('vision_feature_layer', -2)]
if kwargs.get('vision_feature_select_strategy', 'patch') == 'patch':
image_features = image_features[:, 1:]
elif kwargs.get('vision_feature_select_strategy', 'patch') == 'cls_patch':
image_features = image_features
else:
raise ValueError(f"Unexpected select feature: {kwargs.get('vision_feature_select_strategy')}")
return image_features
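    # Shape note (illustrative): for a 336px CLIP ViT-L/14 tower with the default
    # 'patch' strategy, the forward pass returns the second-to-last layer's hidden
    # states with the CLS token removed, i.e. a tensor of shape (batch, 576, 1024).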
# =============================================================================
# CONNECTORS
# =============================================================================
class IdentityConnector(nn.Module):
"""Identity connector (pass-through)"""
def __init__(self, config):
super().__init__()
self._connector = nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self._connector(x)
class LinearConnector(nn.Module):
"""Linear projection connector"""
def __init__(self, config):
super().__init__()
self._connector = nn.Linear(config.vision_hidden_size, config.hidden_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self._connector(x)
class MLPConnector(nn.Module):
"""MLP connector with configurable depth"""
def __init__(self, config):
super().__init__()
        mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', config.connector_type)
        if mlp_gelu_match is None:
            raise ValueError(f"Invalid MLP connector type: {config.connector_type!r}; expected e.g. 'mlp2x_gelu'")
        mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.vision_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
self._connector = nn.Sequential(*modules)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self._connector(x)
class ResamplerConnector(nn.Module):
"""Perceiver resampler connector"""
def __init__(self, config):
super().__init__()
dim = config.hidden_size
depth = config.num_resampler_layers
num_latents = config.num_queries
self.latents = nn.Parameter(torch.randn(num_latents, dim))
self.linear = nn.Linear(config.vision_hidden_size, config.hidden_size)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
nn.MultiheadAttention(dim, num_heads=8, batch_first=True),
nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
)
]))
self.norm = nn.LayerNorm(dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b = x.shape[0]
x = self.linear(x)
latents = self.latents.unsqueeze(0).expand(b, -1, -1)
for attn, ff in self.layers:
latents = latents + attn(latents, x, x)[0]
latents = latents + ff(latents)
return self.norm(latents)
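    # Illustrative note: the resampler compresses a variable number of patch features
    # into a fixed set of learned queries, e.g. mapping (batch, 576, vision_hidden_size)
    # inputs to (batch, num_queries, hidden_size) outputs via cross-attention.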
def build_connector(config):
"""Factory function to build connector based on config"""
connector_type = config.connector_type.lower()
if connector_type == 'identity':
return IdentityConnector(config)
elif connector_type == 'linear':
return LinearConnector(config)
elif 'mlp' in connector_type and 'gelu' in connector_type:
return MLPConnector(config)
elif connector_type == 'resampler':
return ResamplerConnector(config)
else:
raise ValueError(f"Unsupported connector type: {connector_type}")
# =============================================================================
# CONFIGURATION
# =============================================================================
class SMBV1Config(PretrainedConfig):
model_type = "smb_v1"
def __init__(
self,
llm_model_name_or_path: str = '',
tokenizer_name_or_path: str = None,
vision_model_name_or_path: str = '',
connector_type: str = 'identity',
text_config: dict = None,
hidden_size: int = 2048,
vocab_size: int = 32000,
pad_token: str = None,
pad_token_id: int = None,
tokenizer_padding_side: str = 'right',
tokenizer_model_max_length: int = 2048,
vision_config: dict = None,
vision_hidden_size: int = None,
vision_feature_layer: int = -2,
vision_feature_select_strategy: str = 'patch',
image_aspect_ratio: str = 'square',
resampler_hidden_size: int = None,
num_queries: int = None,
num_resampler_layers: int = None,
use_cache: bool = False,
**kwargs
):
self.llm_model_name_or_path = llm_model_name_or_path
self.tokenizer_name_or_path = tokenizer_name_or_path or llm_model_name_or_path
self.vision_model_name_or_path = vision_model_name_or_path
self.connector_type = connector_type
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.pad_token = pad_token
self.pad_token_id = pad_token_id
self.tokenizer_padding_side = tokenizer_padding_side
self.tokenizer_model_max_length = tokenizer_model_max_length
self.vision_feature_layer = vision_feature_layer
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_aspect_ratio = image_aspect_ratio
self.resampler_hidden_size = resampler_hidden_size
self.num_queries = num_queries
self.num_resampler_layers = num_resampler_layers
self.use_cache = use_cache
# Load nested configs
if text_config is not None:
self.text_config = AutoConfig.for_model(**text_config)
else:
self.text_config = None
if vision_config is not None:
self.vision_config = AutoConfig.for_model(**vision_config)
else:
self.vision_config = None
if self.text_config is not None:
self.hidden_size = getattr(self.text_config, 'hidden_size', hidden_size)
self.vocab_size = getattr(self.text_config, 'vocab_size', vocab_size)
if self.vision_config is not None:
self.vision_hidden_size = getattr(self.vision_config, 'hidden_size', vision_hidden_size)
else:
self.vision_hidden_size = vision_hidden_size
super().__init__(**kwargs)
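    # Illustrative config sketch (values are placeholders, not a tested recipe):
    #   SMBV1Config(
    #       llm_model_name_or_path="Qwen/Qwen2-1.5B-Instruct",
    #       vision_model_name_or_path="openai/clip-vit-large-patch14-336",
    #       connector_type="mlp2x_gelu",
    #       text_config={"model_type": "qwen2", ...},            # forwarded to AutoConfig.for_model
    #       vision_config={"model_type": "clip_vision_model", ...},
    #   )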
# =============================================================================
# MAIN MODEL
# =============================================================================
class SMBV1PreTrainedModel(PreTrainedModel):
config_class = SMBV1Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["SMBV1VisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_attention_backend = True
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", 0.02)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
class SMBV1ForConditionalGeneration(SMBV1PreTrainedModel, GenerationMixin):
def __init__(self, config: SMBV1Config):
super().__init__(config)
        # Detect LLM backend from text_config
        if config.text_config is None:
            raise ValueError("`text_config` must be provided to build the language model")
        llm_type = config.text_config.model_type.lower()
if llm_type not in LLM_MAPPING:
raise ValueError(f"Unsupported LLM type: {llm_type}. Supported: {list(LLM_MAPPING.keys())}")
llm_class = LLM_MAPPING[llm_type]
self.language_model = llm_class(config.text_config)
# Get template for this LLM type
template_class = TEMPLATE_MAPPING.get(llm_type, LlamaTemplate)
self.template = template_class()
# Vision tower and connector (optional for text-only)
if config.vision_model_name_or_path:
self.vision_tower = VisionTower(config)
self.connector = build_connector(config)
else:
self.vision_tower = None
self.connector = None
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def tie_weights(self, **kwargs):
return self.language_model.tie_weights(**kwargs)
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.config.vocab_size = model_embeds.num_embeddings
return model_embeds
def encode_images(self, images: torch.Tensor) -> torch.Tensor:
"""Encode images through vision tower and connector"""
kwargs = {
'vision_feature_layer': self.config.vision_feature_layer,
'vision_feature_select_strategy': self.config.vision_feature_select_strategy
}
images = images.to(device=self.device, dtype=self.dtype)
image_features = self.vision_tower(images, **kwargs)
image_features = self.connector(image_features)
return image_features
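    # Shape flow (illustrative): pixel tensors of shape (num_images, 3, H, W) become
    # vision features of shape (num_images, num_patches, vision_hidden_size), which the
    # connector projects to (num_images, num_patches, hidden_size), or to
    # (num_images, num_queries, hidden_size) when the resampler connector is used.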
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
):
"""Prepare inputs by inserting image features into the embedding sequence"""
vision_tower = self.vision_tower
if vision_tower is None or images is None or input_ids.shape[1] == 1:
return input_ids, position_ids, attention_mask, past_key_values, None, labels
image_features = self.encode_images(images)
# Handle dummy tensors
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)
# Remove padding
_input_ids = input_ids
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
new_input_embeds = []
new_labels = []
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
if num_images == 0:
cur_image_features = image_features[cur_image_idx]
cur_input_embeds = self.language_model.get_input_embeddings()(cur_input_ids)
cur_input_embeds = torch.cat([cur_input_embeds, cur_image_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = self.language_model.get_input_embeddings()(torch.cat(cur_input_ids_noim))
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
cur_image_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate to max length
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
if tokenizer_model_max_length is not None:
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Pad and stack
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
attention_mask = torch.zeros((batch_size, max_len), dtype=_attention_mask.dtype if _attention_mask is not None else torch.bool, device=new_labels[0].device)
position_ids = torch.zeros((batch_size, max_len), dtype=torch.long, device=new_labels[0].device)
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
cur_len = cur_new_embed.shape[0]
padding_side = getattr(self.config, 'tokenizer_padding_side', 'right')
if padding_side == "left":
new_input_embeds_padded.append(torch.cat((
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
cur_new_embed
), dim=0))
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
else:
new_input_embeds_padded.append(torch.cat((
cur_new_embed,
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
), dim=0))
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
if _position_ids is None:
position_ids = None
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
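    # Worked example (illustrative): for input_ids [bos, t1, t2, IMAGE_TOKEN_INDEX, t3, t4]
    # and one image yielding 576 patch features, the returned inputs_embeds concatenates
    # the embeddings of [bos, t1, t2], the 576 image features, and the embeddings of
    # [t3, t4], giving a sequence of length 5 + 576 = 581; label positions covering the
    # image features are filled with IGNORE_INDEX.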
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
use_cache = use_cache if use_cache is not None else self.config.use_cache
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels
) = self.prepare_inputs_labels_for_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images,
image_sizes
)
return self.language_model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(
inputs,
position_ids,
attention_mask,
_,
inputs_embeds,
_
) = self.prepare_inputs_labels_for_multimodal(
inputs,
position_ids,
attention_mask,
None,
None,
images,
image_sizes=image_sizes
)
else:
inputs_embeds = self.language_model.get_input_embeddings()(inputs)
return self.language_model.generate(
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
**kwargs
)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = self.language_model.prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if images is not None:
inputs['images'] = images
if image_sizes is not None:
inputs['image_sizes'] = image_sizes
return inputs
def chat(
self,
prompt: str,
tokenizer,
image: str = None,
max_new_tokens: int = 512,
num_beams: int = 1,
top_p: float = None,
temperature: float = 0,
) -> str:
"""
Chat interface for both text-only and multimodal inference.
Args:
prompt: User prompt. Include <image> token for multimodal.
tokenizer: Tokenizer for the model.
image: Path or URL to image (required if <image> in prompt).
max_new_tokens: Maximum tokens to generate.
num_beams: Beam search width.
top_p: Nucleus sampling parameter.
temperature: Sampling temperature.
Returns:
Generated text response.
"""
# Detect if multimodal based on <image> token
has_image = DEFAULT_IMAGE_TOKEN in prompt
if has_image:
if image is None:
raise ValueError("Prompt contains <image> token but no image provided")
if self.vision_tower is None:
raise ValueError("Model has no vision tower but prompt contains <image> token")
# Format prompt using template
formatted_prompt = self.template.format_chat(prompt, has_image)
# Process image if needed
image_tensor = None
if has_image:
pil_image = load_image(image)
image_tensor = process_images([pil_image], self.vision_tower._image_processor, self.config)
if isinstance(image_tensor, list):
image_tensor = torch.stack(image_tensor).to(self.device)
else:
image_tensor = image_tensor.to(self.device)
# Tokenize
input_ids = tokenizer_image_token(formatted_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
        input_ids = input_ids.to(self.device)
# Build list of stop token ids (model-agnostic)
eos_token_ids = []
# Add primary eos_token_id
if tokenizer.eos_token_id is not None:
eos_token_ids.append(tokenizer.eos_token_id)
# Add common stop tokens that might exist
stop_tokens = ["<|im_end|>", "<|endoftext|>", "</s>", "<|eot_id|>"]
        for token in stop_tokens:
            token_id = tokenizer.convert_tokens_to_ids(token)
            # Only add if the token exists in the vocab (not None/UNK) and is not already listed
            if token_id is not None and token_id != tokenizer.unk_token_id and token_id not in eos_token_ids:
                eos_token_ids.append(token_id)
# Stopping criteria for string-based stopping
stop_str = self.template.stop_str
stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
# Build generation kwargs
gen_kwargs = {
"max_new_tokens": max_new_tokens,
"num_beams": num_beams,
"use_cache": True,
"pad_token_id": tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
"eos_token_id": eos_token_ids if len(eos_token_ids) > 1 else eos_token_ids[0] if eos_token_ids else None,
"stopping_criteria": [stopping_criteria],
}
# Add sampling parameters only if needed
if temperature > 0:
gen_kwargs["do_sample"] = True
gen_kwargs["temperature"] = temperature
if top_p is not None:
gen_kwargs["top_p"] = top_p
else:
gen_kwargs["do_sample"] = False
# Generate
with torch.inference_mode():
output_ids = self.generate(
input_ids,
images=image_tensor,
**gen_kwargs
)
        # `generate` is fed `inputs_embeds` only, so `output_ids` already contains just
        # the newly generated tokens; decode them directly.
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
        outputs = outputs.strip()
# Remove stop string if present at end
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)].strip()
# Also remove common stop patterns
for pattern in ["</s>", "<|im_end|>", "<|endoftext|>", "<|eot_id|>"]:
if outputs.endswith(pattern):
outputs = outputs[:-len(pattern)].strip()
return outputs
# =============================================================================
# REGISTER WITH AUTO CLASSES
# =============================================================================
AutoConfig.register("smb_v1", SMBV1Config)
AutoModelForCausalLM.register(SMBV1Config, SMBV1ForConditionalGeneration)