# modeling_logics.py
"""Logics multimodal model: Qwen3 language model + ConvNeXt/SigLIP2 vision towers."""
import ast
import base64
import math
import os
import random
import re
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from io import BytesIO
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel, Qwen3Config, Qwen3Model, Qwen3ForCausalLM, CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig, AutoConfig, AutoModel
from transformers.activations import ACT2FN
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_logics import LogicsConfig
from .convnext_encoder import ConvNextVisionTower
from .siglip_encoder import SigLipVisionTower

# Label value ignored by the loss (padding and image-token positions).
IGNORE_INDEX = -100
# Placeholder id inside input_ids marking where image features are spliced in.
IMAGE_TOKEN_INDEX = -200
# NOTE(review): the four DEFAULT_* token strings below are empty in this copy of the
# file and are not referenced anywhere in this module; confirm the intended values
# (they may have been markup-like tokens stripped in transit) before using them.
DEFAULT_IMAGE_TOKEN = ""
DEFAULT_IMAGE_PATCH_TOKEN = ""
DEFAULT_IM_START_TOKEN = ""
DEFAULT_IM_END_TOKEN = ""


def select_best_resolution(original_size, possible_resolutions):
    """Select the candidate resolution that best fits an image.

    The winner maximises the effective (non-upscaled) pixel count of the image
    once it is scaled into the candidate while keeping its aspect ratio; ties
    are broken by the smallest wasted (uncovered) area.

    Args:
        original_size (tuple): The original image size as (width, height).
        possible_resolutions (list): Candidate resolutions as
            [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best-fit resolution as (width, height), or None when
        ``possible_resolutions`` is empty.
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")

    for width, height in possible_resolutions:
        # Scale the image to fit inside the candidate while keeping its aspect ratio.
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)

        # Effective pixels are capped at the original pixel count (upscaling adds
        # nothing); wasted pixels are the part of the candidate left uncovered.
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (
            effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
        ):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """Compute the patch-grid shape used for an "anyres" image.

    Args:
        image_size (tuple): The input image size as (width, height).
        grid_pinpoints (str | list): Either a list of [width, height] resolutions,
            a string literal of such a list, or a range string such as
            "(1x1),...,(3x3)" whose first and last "(AxB)" entries bound a grid of
            multipliers of ``patch_size``.
        patch_size (int): The side length of each image patch in pixels.

    Returns:
        tuple: The grid shape as (num_patches_wide, num_patches_high).

    Raises:
        ValueError: If a range string is given with an unsupported ``patch_size``.
    """
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        if patch_size not in [224, 336, 384, 448, 512]:
            raise ValueError("patch_size should be in [224, 336, 384, 448, 512]")
        # The first and last "(AxB)" entries define the inclusive multiplier range.
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        grid_pinpoints = [
            (i, j)
            for i in range(range_start[0], range_end[0] + 1)
            for j in range(range_start[1], range_end[1] + 1)
        ]
        # Turn multipliers into pixel resolutions.
        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]

    if isinstance(grid_pinpoints, list):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)

    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size


def unpad_image(tensor, original_size):
    """Remove the padding that letterboxing added around a resized image.

    Args:
        tensor (torch.Tensor): The image (or feature map) in C x H x W format.
        original_size (tuple): The original image size as (width, height).

    Returns:
        torch.Tensor: The tensor cropped to the region covered by the image.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # Image is relatively wider: padding was added above and below.
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        # Image is relatively taller: padding was added left and right.
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor


class MultiBackboneChannelConcatenationVisionTower(nn.Module):
    """Run a ConvNeXt and a SigLIP2 vision tower and concatenate their features per token.

    Each backbone's output is resampled onto a ``grid_size x grid_size`` token
    grid so the per-token features can be concatenated on the last dimension.
    """

    def __init__(self, vision_config: LogicsConfig, grid_size=27):
        super().__init__()

        self.vision_config = vision_config
        self.vision_tower_name_list = vision_config.mm_vision_tower.replace(";", ",").split(",")
        self.is_loaded = False
        self.grid_size = grid_size
        self.num_tokens = self.grid_size ** 2
        self.input_image_size = 384
        self.image_size = 384
        # Tokens per side of the output grid; get_2dPool relies on this matching grid_size.
        self.num_patches_per_side = grid_size

        # --- load_vision_tower -------
        self.vision_towers = nn.ModuleList()

        convnext_config = deepcopy(self.vision_config)
        convnext_config.freeze_vision = False
        convnext_config.input_image_size = 384
        convnext_vision_tower = ConvNextVisionTower("convnext_xxlarge.clip_laion2b_soup", convnext_config)
        convnext_vision_tower.load_model(gradient_checkpointing=True)
        self.vision_towers.append(convnext_vision_tower)
        print("convnext-256 loaded")

        siglip_vision_tower = SigLipVisionTower("siglip2-so400m-patch14-384", vision_tower_cfg=self.vision_config.siglip_config)
        siglip_vision_tower.gradient_checkpointing = True
        siglip_vision_tower.load_model()
        self.vision_towers.append(siglip_vision_tower)
        print("siglip2-384 loaded")

    def forward(self, x):
        """Encode images ``x`` (b, c, H, W) into (b, grid_size**2, sum of tower hidden sizes)."""
        features = []
        for vision_tower in self.vision_towers:
            # Resize the input when a tower expects a different input resolution.
            if vision_tower.input_image_size != self.input_image_size:
                resized_x = F.interpolate(
                    x.float(),
                    size=(vision_tower.input_image_size, vision_tower.input_image_size),
                    mode='bilinear',
                    align_corners=True,
                ).to(dtype=x.dtype)
            else:
                resized_x = x

            feature = vision_tower(resized_x)

            if len(feature.shape) == 3:  # b, n, c
                b, n, c = feature.shape
                if n == self.num_tokens:
                    # Already on the target grid.
                    features.append(feature)
                    continue
                # Assumes a square token grid; reshape tokens into a feature map.
                w = h = int(n ** 0.5)
                feature = feature.transpose(1, 2).reshape(b, c, h, w)
            else:
                b, c, h, w = feature.shape

            if w != self.grid_size:
                feature = F.interpolate(
                    feature.float(),
                    size=(self.grid_size, self.grid_size),
                    mode='bilinear',
                    align_corners=True,
                ).to(dtype=x.dtype)
            features.append(feature.flatten(2, 3).transpose(1, 2))

        return torch.cat(features, dim=-1)

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # Taken from this module's own parameters (the towers in self.vision_towers).
        return next(self.parameters()).dtype

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def config(self):
        # No standalone config is exposed; build_vision_projector receives None here.
        return None

    @property
    def hidden_size(self):
        return sum(tower.hidden_size for tower in self.vision_towers)

    @property
    def num_patches(self):
        return self.num_tokens


def build_vision_resampler(config, delay_load=False, **kwargs):
    """Return a pass-through resampler (no token resampling is applied)."""
    resampler = torch.nn.Identity()
    resampler.config = {"mm_resampler_type": None}
    return resampler


def build_vision_projector(config, delay_load=False, **kwargs):
    """Build the module mapping vision features (mm_hidden_size) to LM hidden_size.

    Supported ``config.mm_projector_type`` values (default "linear"):
      * "linear"       -> a single nn.Linear
      * "mlp{N}x_gelu" -> N Linear layers separated by GELU
      * "identity"     -> nn.Identity

    Raises:
        ValueError: For any other projector type.
    """
    projector_type = getattr(config, "mm_projector_type", "linear")

    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == "identity":
        return nn.Identity()

    raise ValueError(f"Unknown projector type: {projector_type}")


class LogicsMetaForCausalLM(ABC):
    """Mixin adding image/video encoding and input splicing to a causal LM.

    Expects the concrete subclass to provide ``get_model()`` (returning a module
    with ``get_vision_tower()``, ``mm_projector`` and ``embed_tokens``) as well as
    ``self.config``, ``self.model``, ``self.device`` and ``self.training``.
    """

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def get_2dPool(self, image_feature, stride=2):
        """Spatially downsample per-frame token grids by ``stride``.

        ``image_feature`` is (num_frames, num_tokens, dim) with num_tokens equal to
        num_patches_per_side**2; the pooling mode comes from
        ``config.mm_spatial_pool_mode`` ("average", "max" or "bilinear").
        Returns (num_frames, pooled_tokens, dim).
        """
        height = width = self.get_vision_tower().num_patches_per_side
        num_frames, num_tokens, num_dim = image_feature.shape
        image_feature = image_feature.view(num_frames, height, width, -1)
        image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
        if self.config.mm_spatial_pool_mode == "average":
            image_feature = nn.functional.avg_pool2d(image_feature, stride)
        elif self.config.mm_spatial_pool_mode == "max":
            image_feature = nn.functional.max_pool2d(image_feature, stride)
        elif self.config.mm_spatial_pool_mode == "bilinear":
            pooled_height, pooled_width = image_feature.shape[2:]
            scaled_shape = [math.ceil(pooled_height / stride), math.ceil(pooled_width / stride)]
            image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
        else:
            raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")
        image_feature = image_feature.permute(0, 2, 3, 1)
        image_feature = image_feature.view(num_frames, -1, num_dim)
        return image_feature

    def encode_images(self, images):
        """Run the vision tower, then the multimodal projector, on ``images``."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
        """Encode a concatenated batch and return per-sample (slow, faster) features.

        Samples whose index is in ``video_idx_in_batch`` are pooled with
        ``config.mm_spatial_pool_stride`` (when > 1); when ``config.add_faster_video``
        is set, an extra copy pooled with twice that stride is also produced.
        Non-video samples get the placeholder ``0`` as their faster feature.
        """
        videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
        per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0)
        all_videos_or_images_features = []
        all_faster_video_features = []
        cur_mm_spatial_pool_stride = self.config.mm_spatial_pool_stride

        for idx, feat in enumerate(per_videos_or_images_features):
            feat = self.get_model().mm_projector(feat)
            faster_video_feature = 0
            slower_img_feat = None
            if idx in video_idx_in_batch and cur_mm_spatial_pool_stride > 1:
                slower_img_feat = self.get_2dPool(feat, cur_mm_spatial_pool_stride)
                if self.config.add_faster_video:
                    # Local doubling only: the slow stride must stay the same for later samples.
                    faster_video_feature = self.get_2dPool(feat, cur_mm_spatial_pool_stride * 2)
            # `is not None` instead of comparing a tensor to 0, which is ambiguous for multi-element tensors.
            if slower_img_feat is not None:
                all_videos_or_images_features.append(slower_img_feat)
            else:
                all_videos_or_images_features.append(feat)
            all_faster_video_features.append(faster_video_feature)
        return all_videos_or_images_features, all_faster_video_features

    def add_token_per_grid(self, image_feature):
        """Append ``image_newline`` after every row of each frame's token grid.

        ``image_feature`` is (num_frames, side**2, dim). Returns
        (num_frames * side * (side + 1), dim), or (num_frames, side * (side + 1), dim)
        when ``config.add_faster_video`` is set.
        """
        resize_h = int(math.sqrt(image_feature.shape[1]))
        num_frames = image_feature.shape[0]
        feature_dim = image_feature.shape[-1]

        image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1)
        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
        image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
        if self.config.add_faster_video:
            # (dim, frames * side, side + 1) -> (dim, frames, side, side + 1)
            image_feature = image_feature.view(feature_dim, num_frames, resize_h, -1)
            # -> (frames, side, side + 1, dim)
            image_feature = image_feature.permute(1, 2, 3, 0).contiguous()
            # -> (frames, side * (side + 1), dim)
            image_feature = image_feature.flatten(1, 2)
            return image_feature
        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
        return image_feature

    def add_token_per_frame(self, image_feature):
        """Append one ``image_newline`` token to each frame: (F, N, D) -> (F, N + 1, D)."""
        image_feature = image_feature.permute(2, 0, 1).contiguous()
        image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
        image_feature = image_feature.permute(1, 2, 0).contiguous()
        return image_feature

    def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=None, image_sizes=None):
        """Replace IMAGE_TOKEN_INDEX placeholders in ``input_ids`` with image features.

        Args:
            input_ids: Token ids (batch, seq) containing IMAGE_TOKEN_INDEX placeholders.
            position_ids, attention_mask, labels: Optional per-token tensors; when
                given, padded/rebuilt versions are returned, otherwise None.
            past_key_values: Passed through unchanged.
            images: A tensor (4-D or 5-D) or a list of per-sample tensors.
            modalities: Per-sample modality strings ("image" / "video");
                None is treated as ["image"].
            image_sizes: Per-sample original image sizes, used by anyres merging.

        Returns:
            ``(None, position_ids, attention_mask, past_key_values, inputs_embeds, labels)``,
            or the inputs unchanged with ``inputs_embeds=None`` when there is no vision
            tower, no images, or a single-token decode step.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if modalities is None:
            modalities = ["image"]

        if type(images) is list or images.ndim == 5:
            if type(images) is list:
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]

            video_idx_in_batch = [idx for idx, modality in enumerate(modalities) if modality == "video"]

            images_list = [image if image.ndim == 4 else image.unsqueeze(0) for image in images]
            concat_images = torch.cat(images_list, dim=0)
            split_sizes = [image.shape[0] for image in images_list]
            encoded_image_features = self.encode_images(concat_images)
            # One [num_images, num_tokens, dim] tensor per sample.
            encoded_image_features = torch.split(encoded_image_features, split_sizes)
            image_features = []
            for idx, image_feat in enumerate(encoded_image_features):
                if idx in video_idx_in_batch:
                    image_features.append(self.get_2dPool(image_feat))
                else:
                    image_features.append(image_feat)

            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
            mm_newline_position = getattr(self.config, "mm_newline_position", "one_token")

            if mm_patch_merge_type == "flat":
                image_features = [x.flatten(0, 1) for x in image_features]

            elif mm_patch_merge_type.startswith("spatial"):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_idx in video_idx_in_batch:  # video operations
                        if mm_newline_position == "grid":
                            image_feature = self.add_token_per_grid(image_feature)
                            if self.config.add_faster_video:
                                raise NotImplementedError(
                                    "add_faster_video is not supported: faster-video features are only "
                                    "produced by encode_multimodals(), which this code path does not call"
                                )
                            new_image_features.append(image_feature)
                        elif mm_newline_position == "frame":
                            image_feature = self.add_token_per_frame(image_feature)
                            new_image_features.append(image_feature.flatten(0, 1))
                        elif mm_newline_position == "one_token":
                            image_feature = image_feature.flatten(0, 1)
                            if 'unpad' in mm_patch_merge_type:
                                image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)
                            new_image_features.append(image_feature)
                        elif mm_newline_position == "no_token":
                            new_image_features.append(image_feature.flatten(0, 1))
                        else:
                            raise ValueError(f"Unexpected mm_newline_position: {mm_newline_position}")

                    elif image_feature.shape[0] > 1:  # multi patches and multi images operations
                        # The first crop is the base (global) view; the rest are tiles.
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = int(math.sqrt(base_image_feature.shape[0]))

                        matched_anyres_max_num_patches = None
                        if "anyres_max" in image_aspect_ratio:
                            matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
                            if matched_anyres_max_num_patches:
                                max_num_patches = int(matched_anyres_max_num_patches.group(1))

                        if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
                            if hasattr(self.get_vision_tower(), "image_size"):
                                vision_tower_image_size = self.get_vision_tower().image_size
                            else:
                                raise ValueError("vision_tower_image_size is not found in the vision tower.")
                            try:
                                num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
                            except Exception as e:
                                # Best-effort fallback kept from the original: assume a 2x2 tile grid.
                                print(f"Error: {e}")
                                num_patch_width, num_patch_height = 2, 2
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            image_feature = image_feature.view(2, 2, height, width, -1)

                        if "maxpool2x2" in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = nn.functional.max_pool2d(image_feature, 2)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
                            unit = image_feature.shape[2]
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
                            _, h, w = image_feature.shape
                            # Downscale when the unpadded map exceeds max_num_patches tiles (10% slack).
                            times = math.sqrt(h * w / (max_num_patches * unit**2))
                            if times > 1.1:
                                image_feature = image_feature[None]
                                image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
                            image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        elif "unpad" in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
                            image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        else:
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)

                        if "nobase" not in mm_patch_merge_type:
                            image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                        new_image_features.append(image_feature)

                    else:  # single image operations
                        image_feature = image_feature[0]
                        if "unpad" in mm_patch_merge_type:
                            image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)
                        new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images)

        if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
            raise NotImplementedError

        # Remember which optional inputs the caller supplied so they can be returned as None.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # Remove padding using the attention mask.
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                # Appends an empty slice of the image features; presumably to keep the
                # vision branch in the autograd graph for text-only samples — confirm.
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            # Text segments lie between consecutive image placeholders (sentinels at both ends).
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_labels = labels[batch_idx]
            cur_input_ids_noim = []
            cur_labels_noim = []
            for start, end in zip(image_token_indices[:-1], image_token_indices[1:]):
                cur_input_ids_noim.append(cur_input_ids[start + 1 : end])
                cur_labels_noim.append(cur_labels[start + 1 : end])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)

            cur_new_input_embeds = []
            cur_new_labels = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    try:
                        cur_image_features = image_features[cur_image_idx]
                    except IndexError:
                        # More placeholders than image features: reuse the previous features.
                        cur_image_features = image_features[cur_image_idx - 1]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate to the tokenizer max length (image embeddings lengthen the sequence).
        # Every sample is kept regardless of len(modalities).
        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
        new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
        new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Re-pad into a batch.
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        pad_left = getattr(self.config, "tokenizer_padding_side", "right") == "left"
        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            padding = torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
            if pad_left:
                new_input_embeds_padded.append(torch.cat((padding, cur_new_embed), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, padding), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        new_labels = None if _labels is None else new_labels_padded
        attention_mask = None if _attention_mask is None else attention_mask.to(dtype=_attention_mask.dtype)
        if _position_ids is None:
            position_ids = None

        if getattr(self.config, "use_pos_skipping", False) and self.training:
            # Training-time augmentation: shift positions by random offsets on each side of a random split.
            position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
            split_position = random.randint(0, new_input_embeds.size(1))
            left_add = random.randint(0, self.config.pos_skipping_range)
            right_add = random.randint(left_add, self.config.pos_skipping_range)
            position_ids[:, :split_position] += left_add
            position_ids[:, split_position:] += right_add

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels


class LogicsQwen3Model(Qwen3Model):
    """Qwen3 decoder extended with the dual-backbone vision tower, resampler and projector."""

    config_class = LogicsConfig

    def __init__(self, config: LogicsConfig):
        super().__init__(config)
        self.config = config
        self.vision_tower = MultiBackboneChannelConcatenationVisionTower(config)
        self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
        self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            # Created uninitialized (torch.empty); presumably filled from the checkpoint — confirm.
            self.image_newline = nn.Parameter(
                torch.empty(config.hidden_size, dtype=torch.bfloat16)
            )

    def get_vision_tower(self):
        vision_tower = getattr(self, "vision_tower", None)
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower


class LogicsForConditionalGeneration(Qwen3ForCausalLM, LogicsMetaForCausalLM):
    """Qwen3 causal LM whose inputs may contain images spliced in at IMAGE_TOKEN_INDEX."""

    config_class = LogicsConfig

    def __init__(self, config: LogicsConfig):
        Qwen3ForCausalLM.__init__(self, config)
        self.config = config
        self.config.rope_scaling = None
        # Replace the plain decoder built by the parent constructor with the multimodal one.
        self.model = LogicsQwen3Model(config)
        self.post_init()
        print(f"config:{config}")

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        modalities: Optional[List[str]] = None,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Splice images into the inputs (when ``inputs_embeds`` is not given) and run the LM.

        ``modalities`` defaults to ["image"]. ``cache_position`` is accepted but not
        forwarded to the parent forward.
        """
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids=input_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                labels=labels,
                images=images,
                modalities=modalities,
                image_sizes=image_sizes,
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Forward ``images`` / ``image_sizes`` through generation steps."""
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs

    @torch.no_grad()
    def generate(
        self,
        text_inputs,
        images_inputs=None,
        image_sizes=None,
        modalities=None,
        position_ids=None,
        attention_mask=None,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate text from token ids (optionally with images) using sampling settings from config."""
        kwargs = {
            'do_sample': self.config.do_sample,
            'temperature': self.config.temperature,
            'top_p': self.config.top_p,
            'num_beams': self.config.num_beams,
            'max_new_tokens': self.config.max_new_tokens,
            'use_cache': self.config.use_cache,
            'repetition_penalty': self.config.repetition_penalty,
        }
        if images_inputs is not None:
            (_, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(
                text_inputs, position_ids, attention_mask, None, None, images_inputs, modalities, image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.model.embed_tokens(text_inputs)
        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)