File size: 1,224 Bytes
03ea185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from transformers import PretrainedConfig
import torch

class MIC21SummarizerConfig(PretrainedConfig):
    """Configuration for the MIC21 summarizer model.

    Stores the constructor arguments as attributes and registers the
    ``auto_map`` entries that point ``AutoConfig`` / ``AutoModel`` at the
    ``jkralev/mic21_model`` repository code.

    Args:
        hf_text_model: Identifier of the text model (presumably a Hugging
            Face Hub id -- confirm against the modeling code).
        hf_image_model: Identifier of the image model (presumably a Hugging
            Face Hub id -- confirm against the modeling code).
        im_model_cuda_id: Presumably the CUDA device index for the image
            model -- confirm against the modeling code.
        device_map: Stored as-is; presumably forwarded when the text model
            is loaded -- confirm against the modeling code.
        memory_map: Stored as-is. When omitted (``None``) a fresh empty
            dict is created for this instance.
        attn_implementation: Stored as-is; presumably the attention
            implementation name used when loading the text model.
        in_device: Stored as-is; presumably the device for model inputs.
        out_device: Stored as-is; presumably the device for model outputs.
        output_length: Stored as-is; presumably the generated summary
            length -- units are not visible here.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "mic21_summarizer"

    def __init__(
        self,
        hf_text_model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        hf_image_model="microsoft/resnet-50",
        im_model_cuda_id=0,
        device_map="auto",
        memory_map=None,
        #text_model_dtype = torch.float16,
        attn_implementation="eager",
        in_device=0,
        out_device=0,
        output_length=40,
        **kwargs,
    ):
        self.hf_text_model = hf_text_model
        self.hf_image_model = hf_image_model
        self.im_model_cuda_id = im_model_cuda_id
        self.device_map = device_map
        # A ``{}`` default in the signature would be one dict shared by every
        # instance; build a new one per instance instead. A caller-supplied
        # dict is stored as the same object, exactly as before.
        self.memory_map = {} if memory_map is None else memory_map
        #self.text_model_dtype = text_model_dtype
        self.attn_implementation = attn_implementation
        self.in_device = in_device
        self.out_device = out_device
        self.output_length = output_length
        # Mapping from Auto* class names to the remote code implementing them.
        self.auto_map = {
            "AutoConfig": "jkralev/mic21_model--configuration_mic21.MIC21SummarizerConfig",
            "AutoModel": "jkralev/mic21_model--modeling_mic21.MIC21SummarizerModel",
        }
        super().__init__(**kwargs)