from transformers import PretrainedConfig
import torch


class MIC21SummarizerConfig(PretrainedConfig):
    """Configuration for the MIC21 summarizer model.

    Bundles the identifiers of the underlying Hugging Face text and image
    models together with device-placement and generation settings, and
    registers the custom ``auto_map`` so ``AutoConfig``/``AutoModel`` can
    resolve this config/model pair from the ``jkralev/mic21_model`` repo.
    """

    model_type = "mic21_summarizer"

    def __init__(
        self,
        hf_text_model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        hf_image_model="microsoft/resnet-50",
        im_model_cuda_id=0,
        device_map="auto",
        memory_map=None,
        attn_implementation="eager",
        in_device=0,
        out_device=0,
        output_length=40,
        **kwargs,
    ):
        """Create a MIC21 summarizer configuration.

        Args:
            hf_text_model: Hub id of the text (language) model.
            hf_image_model: Hub id of the image (vision) model.
            im_model_cuda_id: CUDA device index for the image model.
                (presumably; verify against the modeling code)
            device_map: ``device_map`` forwarded to the text-model loader.
            memory_map: Per-device memory limits (``max_memory``-style dict).
                Defaults to an empty dict. NOTE: the original default was a
                mutable ``{}`` shared across instances; a ``None`` sentinel
                is used instead so each instance gets its own dict.
            attn_implementation: Attention backend name (e.g. ``"eager"``).
            in_device: Device index for model inputs.
            out_device: Device index for model outputs.
            output_length: Maximum length of the generated summary, in tokens.
                (assumed token count — TODO confirm against generation code)
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        self.hf_text_model = hf_text_model
        self.hf_image_model = hf_image_model
        self.im_model_cuda_id = im_model_cuda_id
        self.device_map = device_map
        # Fresh dict per instance when the caller does not supply one.
        self.memory_map = memory_map if memory_map is not None else {}
        self.attn_implementation = attn_implementation
        self.in_device = in_device
        self.out_device = out_device
        self.output_length = output_length
        # Lets AutoConfig/AutoModel locate the custom classes on the Hub.
        self.auto_map = {
            "AutoConfig": "jkralev/mic21_model--configuration_mic21.MIC21SummarizerConfig",
            "AutoModel": "jkralev/mic21_model--modeling_mic21.MIC21SummarizerModel",
        }
        super().__init__(**kwargs)