from typing import Dict, List, Optional, Union

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass

from .transformer import (
    LayerNormFp32,
    LayerNorm,
    QuickGELU,
    MultimodalTransformer,
)
from .model import CLIPTextCfg, _build_text_tower

try:
    from transformers import (
        BeamSearchScorer,
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
        MaxLengthCriteria,
        StopStringCriteria,
        EosTokenCriteria,
        StoppingCriteriaList,
    )

    GENERATION_TYPES = {
        "top_k": TopKLogitsWarper,
        "top_p": TopPLogitsWarper,
        "beam_search": "beam_search",
    }
    _has_transformers = True
except ImportError:
    GENERATION_TYPES = {
        "top_k": None,
        "top_p": None,
        "beam_search": "beam_search",
    }
    _has_transformers = False


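# MultimodalCfg extends CLIPTextCfg with the decoder-specific fields
# (per-head width, head count, attention-pooler queries) consumed by
# _build_text_decoder_tower below.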
@dataclass
class MultimodalCfg(CLIPTextCfg):
    mlp_ratio: int = 4
    dim_head: int = 64
    heads: int = 8
    n_queries: int = 256
    attn_pooler_heads: int = 8


def _build_text_decoder_tower(
    embed_dim,
    multimodal_cfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
):
    multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
    act_layer = QuickGELU if quick_gelu else nn.GELU
    norm_layer = (
        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    )

    decoder = MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )

    return decoder


def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
    """Coerce an int or list of ints into a 1D tensor on the given device."""
    if not isinstance(token_id, torch.Tensor):
        if isinstance(token_id, int):
            token_id = [token_id]
        token_id = torch.tensor(token_id, device=device)
    return token_id


class CoCa(nn.Module):
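    """Contrastive Captioner (CoCa, https://arxiv.org/abs/2205.01917): a
    CLIP-style image/text dual encoder paired with a multimodal text decoder
    that cross-attends to image tokens to produce captioning logits.
    """
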
    def __init__(
        self,
        embed_dim,
        multimodal_cfg: MultimodalCfg,
        text_cfg: CLIPTextCfg,
        vision_cfg=None,
        quick_gelu: bool = False,
        init_logit_scale: float = np.log(1 / 0.07),
        init_logit_bias: Optional[float] = None,
        nonscalar_logit_scale: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        pad_id: int = 0,
    ):
        super().__init__()
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        vocab_size = (
            self.text.vocab_size
            if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
            else text_cfg.vocab_size
        )

        if vision_cfg is not None:
            from .model import CLIPVisionCfg, _build_vision_tower
            vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
            self.visual = _build_vision_tower(
                embed_dim=embed_dim,
                vision_cfg=vision_cfg,
                quick_gelu=quick_gelu,
                cast_dtype=cast_dtype,
            )
        else:
            self.visual = None

        self.text_decoder = _build_text_decoder_tower(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None
        self.pad_id = pad_id

        self.context_length = multimodal_cfg.context_length

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        if self.visual is not None:  # the visual tower is optional (vision_cfg may be None)
            self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def _encode_image(self, images, normalize: bool = True):
        image_latent, tokens_embs = self.visual(images)
        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
        return image_latent, tokens_embs

    def _encode_text(self, text, normalize: bool = True):
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, images, normalize: bool = True):
        image_latent, _ = self._encode_image(images, normalize=normalize)
        return image_latent

    def encode_text(self, text, normalize: bool = True):
        text_latent, _ = self._encode_text(text, normalize=normalize)
        return text_latent

    def forward_intermediates(
        self,
        image: Optional[torch.Tensor] = None,
        text: Optional[torch.Tensor] = None,
        image_indices: Optional[Union[int, List[int]]] = None,
        text_indices: Optional[Union[int, List[int]]] = None,
        stop_early: bool = False,
        normalize: bool = True,
        normalize_intermediates: bool = False,
        intermediates_only: bool = False,
        image_output_fmt: str = 'NCHW',
        image_output_extra_tokens: bool = False,
        text_output_fmt: str = 'NLC',
        text_output_extra_tokens: bool = False,
        output_logits: bool = False,
        output_logit_scale_bias: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """Forward pass that also returns intermediate features.

        Args:
            image: Input image tensor
            text: Input text tensor
            image_indices: For the image tower, take the last n blocks if int, all if None, or matching indices if a sequence
            text_indices: For the text tower, take the last n blocks if int, all if None, or matching indices if a sequence
            stop_early: Stop iterating over blocks once the last desired intermediate is hit
            normalize: L2 normalize the final image and text features (if present)
            normalize_intermediates: Apply the final encoder norm layer to all intermediates (if possible)
            intermediates_only: Only return intermediate features, not final features
            image_output_fmt: Shape of intermediate image feature outputs
            image_output_extra_tokens: Return both prefix and spatial intermediate tokens
            text_output_fmt: Shape of intermediate text feature outputs
            text_output_extra_tokens: Return both prefix and spatial intermediate tokens
            output_logits: Include logits in the output
            output_logit_scale_bias: Include the logit scale (and bias, if set) in the output
        Returns:
            Dict of image/text intermediates and, depending on flags, final
            features, logit scale, and logit bias.
        """
        output = {}
        if intermediates_only:
            # intermediates-only mode skips final feature normalization and logits
            normalize = False
            output_logits = False
        if output_logits:
            raise NotImplementedError('FIXME: logits output not yet implemented')

        if image is not None:
            image_output = self.visual.forward_intermediates(
                image,
                indices=image_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=image_output_fmt,
                output_extra_tokens=image_output_extra_tokens,
            )
            if normalize and "image_features" in image_output:
                image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
            output.update(image_output)

        if text is not None:
            text_output = self.text.forward_intermediates(
                text,
                indices=text_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=text_output_fmt,
                output_extra_tokens=text_output_extra_tokens,
            )
            if normalize and "text_features" in text_output:
                text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
            output.update(text_output)

        # expose the (exponentiated) logit scale and optional bias if requested
        logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
        if output_logit_scale_bias:
            output["logit_scale"] = logit_scale_exp
            if self.logit_bias is not None:
                output['logit_bias'] = self.logit_bias

        return output

    def forward(
        self,
        image,
        text: Optional[torch.Tensor] = None,
        image_latent: Optional[torch.Tensor] = None,
        image_embs: Optional[torch.Tensor] = None,
        output_labels: bool = True,
    ):
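        """Encode image and text, then decode captions against image tokens.

        Returns a dict with contrastive features ("image_features",
        "text_features"), captioning "logits", the exponentiated
        "logit_scale", and, when `output_labels` is True, next-token "labels"
        aligned with the logits.
        """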
        if image_latent is None or image_embs is None:
            image_latent, image_embs = self._encode_image(image)

        if text is None:
            return {"image_features": image_latent, "image_embs": image_embs}

        text_latent, token_embs = self._encode_text(text)

        # shift labels so that position i predicts token i + 1 (teacher forcing)
        labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
        if output_labels:
            # drop the last token embedding to align the logits with the shifted labels
            token_embs = token_embs[:, :-1]

        logits = self.text_decoder(image_embs, token_embs)
        out_dict = {
            "image_features": image_latent,
            "text_features": text_latent,
            "logits": logits,
            "logit_scale": self.logit_scale.exp(),
        }
        if labels is not None:
            out_dict["labels"] = labels
        if self.logit_bias is not None:
            out_dict["logit_bias"] = self.logit_bias
        return out_dict

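    # Illustrative training sketch (an assumption, not part of this module):
    # the logits/labels pair above feeds a captioning cross-entropy, roughly
    #   out = model(images, texts)
    #   caption_loss = F.cross_entropy(
    #       out["logits"].permute(0, 2, 1),  # [B, vocab, seq-1]
    #       out["labels"],                   # [B, seq-1]
    #       ignore_index=model.pad_id,
    #   )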

    def generate(
        self,
        image,
        text=None,
        seq_len=30,
        max_seq_len=77,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,
        top_k=1,
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False,
    ):
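        """Autoregressively generate caption tokens for `image`.

        `generation_type` selects "beam_search" (group beam search via Hugging
        Face's BeamSearchScorer) or "top_k"/"top_p" sampling. The default
        SOT/EOS ids (49406/49407) match the CLIP BPE tokenizer. When
        `fixed_output_length` is set, outputs are padded to exactly `seq_len`.

        Illustrative call (assumed inputs):
        `model.generate(images, generation_type="top_k", top_k=50)`.
        """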
        # generation loop built on Hugging Face logits processors and stopping
        # criteria (requires the optional `transformers` dependency)
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
        device = image.device

        with torch.no_grad():
            sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
            eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
            pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
            stopping_criteria = StoppingCriteriaList(stopping_criteria)

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    image_inputs=image,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                )
                if fixed_output_length and output.shape[1] < seq_len:
                    pad_len = seq_len - output.shape[1]
                    return torch.cat(
                        (
                            output,
                            torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id,
                        ),
                        dim=1,
                    )
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    "generation_type has to be one of "
                    + "| " + " | ".join(GENERATION_TYPES.keys()) + " |."
                )

            image_latent, image_embs = self._encode_image(image)

            if text is None:
                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            if num_dims == 1:
                text = text[None, :]

            self.eval()
            out = text

            while True:
                x = out[:, -max_seq_len:]
                cur_len = x.shape[1]
                logits = self(
                    image,
                    x,
                    image_latent=image_latent,
                    image_embs=image_embs,
                    output_labels=False,
                )["logits"][:, -1]
                # sequences that already emitted EOS (or padding) are frozen
                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                if mask.all():
                    if not fixed_output_length:
                        break
                else:
                    logits = logits[~mask, :]
                    filtered_logits = logit_processor(x[~mask, :], logits)
                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)

                    if cur_len + 1 == seq_len:
                        # force EOS on the final step
                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
                    else:
                        sample[~mask, :] = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)

                cur_len += 1

                if all(stopping_criteria(out, None)):
                    break

            if num_dims == 1:
                out = out.squeeze(0)

            self.train(was_training)
            return out

    def _generate_beamsearch(
        self,
        image_inputs,
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        logit_processor=None,
        logit_warper=None,
    ):
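        """Group (diverse) beam search over caption tokens.

        Beams are split into `num_beam_groups` groups that are advanced one
        group at a time, following Hugging Face's group beam search; each image
        is repeated `num_beams` times so every beam attends to its own copy of
        the image tokens.
        """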
        device = image_inputs.device
        batch_size = image_inputs.shape[0]
        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
        image_latent, image_embs = self._encode_image(image_inputs)

        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise the score of the first beam of each group to 0 and the rest
        # to -1e9, so beams within the same group don't all produce the same tokens
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:
            # tokens predicted in the current step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices that will re-order the beams for the next step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in the batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
            outputs = self(
                model_inputs['images'],
                model_inputs['text'],
                image_latent=image_latent,
                image_embs=image_embs,
                output_labels=False,
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of the current group's beams among all sequences in the batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of the current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search: merge the beam and vocab dimensions
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # stateless beam update
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                    group_index=beam_group_idx,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of the index inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
                    + group_start_idx
                    + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            # increase cur_len and check whether generation is done
            cur_len = cur_len + 1
            if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
                break

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']


def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
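    """Assemble per-step decoder inputs; mirrors the Hugging Face
    `prepare_inputs_for_generation` hook. With a cached `past`, only the last
    token is fed; `position_ids` are derived from `attention_mask` when needed.
    """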
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    else:
        position_ids = None
    return {
        "text": input_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }