| """ PyTorch JapaneseStableLMAlpha model. """ |

import torch
from torch import nn
from transformers import (
    InstructBlipPreTrainedModel,
    InstructBlipVisionModel,
    InstructBlipQFormerModel,
    InstructBlipForConditionalGeneration,
)
from transformers.utils import logging

from .modeling_japanese_stablelm_alpha import JapaneseStableLMAlphaForCausalLM
from .configuration_japanese_instructblip_alpha import JapaneseInstructBlipAlphaConfig


logger = logging.get_logger(__name__)


class JapaneseInstructBlipAlphaForConditionalGeneration(InstructBlipForConditionalGeneration):
    config_class = JapaneseInstructBlipAlphaConfig

    def __init__(self, config: JapaneseInstructBlipAlphaConfig):
        # Bypass InstructBlipForConditionalGeneration.__init__ so that the language model
        # can be replaced with JapaneseStableLMAlphaForCausalLM below.
        InstructBlipPreTrainedModel.__init__(self, config)

        self.vision_model = InstructBlipVisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = InstructBlipQFormerModel(config.qformer_config)

        # Projects Q-Former outputs into the language model's embedding space.
        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)

        if config.use_decoder_only_language_model:
            language_model = JapaneseStableLMAlphaForCausalLM(config.text_config)
        else:
            # An encoder-decoder language model (e.g. AutoModelForSeq2SeqLM.from_config(
            # config.text_config, trust_remote_code=True)) is not supported here.
            raise NotImplementedError("JapaneseInstructBlipAlpha only supports decoder-only language models.")

        if language_model._no_split_modules is not None:
            self._no_split_modules.extend(language_model._no_split_modules)

        if language_model._keep_in_fp32_modules is not None:
            self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)

        self.language_model = language_model

        self.post_init()
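

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# It assumes a checkpoint directory that ships this code together with a
# compatible processor; the path, image file, and prompt below are placeholders,
# not an officially recommended template.
#
#     import torch
#     from PIL import Image
#     from transformers import AutoProcessor
#
#     ckpt = "path/to/japanese-instructblip-alpha"  # hypothetical checkpoint path
#     model = JapaneseInstructBlipAlphaForConditionalGeneration.from_pretrained(ckpt)
#     processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
#
#     image = Image.open("example.jpg")  # placeholder image
#     # Prompt means "What is this a photo of?"
#     inputs = processor(images=image, text="これは何の写真ですか。", return_tensors="pt")
#     with torch.no_grad():
#         generated_ids = model.generate(**inputs, max_new_tokens=32)
#     print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
# -----------------------------------------------------------------------------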