ccloud0525 committed on
Commit
b92e396
·
1 Parent(s): 1669dbd

feat: "first commit"

Browse files
Files changed (2) hide show
  1. modality_connector.py +26 -17
  2. ts_generation_mixin.py +19 -2
modality_connector.py CHANGED
@@ -11,22 +11,27 @@ from .configuration_aurora import AuroraConfig
11
 
12
 
13
  class VisionEncoder(nn.Module):
14
- config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vit_config')
15
  def __init__(self, config: AuroraConfig):
16
  super().__init__()
17
- self.processor = UnifiedImageProcessor(config)
18
- self.model = ViTModel(ViTConfig.from_json_file(os.path.join(self.config_path, 'config.json')))
 
 
 
 
 
 
 
19
  for param in self.model.parameters():
20
  param.requires_grad = False
 
21
  self.hidden_size = self.model.config.hidden_size
22
  self.output_dim = config.hidden_size
23
  self.num_distill = config.num_distill
24
 
25
  self.projection = nn.Linear(self.hidden_size, self.output_dim)
26
-
27
  self.target_vision_tokens = nn.Parameter(torch.randn(self.num_distill, self.output_dim))
28
 
29
- # Cross-attention layer
30
  self.cross_vision = nn.TransformerDecoder(
31
  nn.TransformerDecoderLayer(
32
  d_model=config.hidden_size,
@@ -68,16 +73,16 @@ class VisionEncoder(nn.Module):
68
 
69
 
70
  class UnifiedImageProcessor(nn.Module):
71
- config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vit_config')
72
- def __init__(self, config: AuroraConfig):
73
  super().__init__()
74
- # Load ViT preprocessor to get pretrained normalization parameters and target size
75
- self.vit_processor = ViTImageProcessor.from_json_file(os.path.join(self.config_path, 'preprocessor_config.json'))
76
- self.target_size = self.vit_processor.size["height"] # e.g., 224 (default ViT input size)
77
 
78
- # Define resizer for pseudo-images (matches real image target size)
79
- self.pseudo_resizer = Resize((self.target_size, self.target_size))
80
 
 
 
 
 
 
81
  self.token_len = config.token_len
82
 
83
  def process_real_image(self, images):
@@ -107,7 +112,7 @@ class UnifiedImageProcessor(nn.Module):
107
  period = input_length
108
 
109
  padding_length = (period - (input_length %
110
- period)) % period
111
  x_pad = F.pad(x, (padding_length, 0))
112
  x_2d = einops.rearrange(x_pad, 'b (p f) -> b 1 f p', f=period)
113
 
@@ -124,20 +129,24 @@ class UnifiedImageProcessor(nn.Module):
124
 
125
 
126
  class TextEncoder(nn.Module):
127
- config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert_config')
128
  def __init__(self, config: AuroraConfig):
129
  super().__init__()
130
- self.model = BertModel(BertConfig.from_json_file(os.path.join(self.config_path, 'config.json')))
 
 
 
 
 
 
131
  for param in self.model.parameters():
132
  param.requires_grad = False
 
133
  self.hidden_size = self.model.config.hidden_size
134
  self.output_dim = config.hidden_size
135
  self.num_distill = config.num_distill
136
  self.max_length = 125
137
 
138
  self.projection = nn.Linear(self.hidden_size, self.output_dim)
139
-
140
- # Define learnable target tokens (shape: [num_distill_tokens, hidden_size])
141
  self.target_text_tokens = nn.Parameter(torch.randn(self.num_distill, self.output_dim))
142
 
143
  self.cross_text = nn.TransformerDecoder(
 
11
 
12
 
13
  class VisionEncoder(nn.Module):
 
14
  def __init__(self, config: AuroraConfig):
15
  super().__init__()
16
+
17
+ base_dir = os.path.dirname(os.path.abspath(__file__))
18
+ self.config_path = os.path.join(base_dir, "vit_config")
19
+
20
+ self.processor = UnifiedImageProcessor(config, self.config_path)
21
+
22
+ vit_config_file = os.path.join(self.config_path, "config.json")
23
+ self.model = ViTModel(ViTConfig.from_json_file(vit_config_file))
24
+
25
  for param in self.model.parameters():
26
  param.requires_grad = False
27
+
28
  self.hidden_size = self.model.config.hidden_size
29
  self.output_dim = config.hidden_size
30
  self.num_distill = config.num_distill
31
 
32
  self.projection = nn.Linear(self.hidden_size, self.output_dim)
 
33
  self.target_vision_tokens = nn.Parameter(torch.randn(self.num_distill, self.output_dim))
34
 
 
35
  self.cross_vision = nn.TransformerDecoder(
36
  nn.TransformerDecoderLayer(
37
  d_model=config.hidden_size,
 
73
 
74
 
75
  class UnifiedImageProcessor(nn.Module):
76
+ def __init__(self, config: AuroraConfig, vit_config_path: str):
 
77
  super().__init__()
 
 
 
78
 
79
+ self.config_path = vit_config_path
 
80
 
81
+ processor_file = os.path.join(self.config_path, "preprocessor_config.json")
82
+ self.vit_processor = ViTImageProcessor.from_json_file(processor_file)
83
+
84
+ self.target_size = self.vit_processor.size["height"]
85
+ self.pseudo_resizer = Resize((self.target_size, self.target_size))
86
  self.token_len = config.token_len
87
 
88
  def process_real_image(self, images):
 
112
  period = input_length
113
 
114
  padding_length = (period - (input_length %
115
+ period)) % period
116
  x_pad = F.pad(x, (padding_length, 0))
117
  x_2d = einops.rearrange(x_pad, 'b (p f) -> b 1 f p', f=period)
118
 
 
129
 
130
 
131
  class TextEncoder(nn.Module):
 
132
  def __init__(self, config: AuroraConfig):
133
  super().__init__()
134
+
135
+ base_dir = os.path.dirname(os.path.abspath(__file__))
136
+ self.config_path = os.path.join(base_dir, "bert_config")
137
+
138
+ bert_config_file = os.path.join(self.config_path, "config.json")
139
+ self.model = BertModel(BertConfig.from_json_file(bert_config_file))
140
+
141
  for param in self.model.parameters():
142
  param.requires_grad = False
143
+
144
  self.hidden_size = self.model.config.hidden_size
145
  self.output_dim = config.hidden_size
146
  self.num_distill = config.num_distill
147
  self.max_length = 125
148
 
149
  self.projection = nn.Linear(self.hidden_size, self.output_dim)
 
 
150
  self.target_text_tokens = nn.Parameter(torch.randn(self.num_distill, self.output_dim))
151
 
152
  self.cross_text = nn.TransformerDecoder(
ts_generation_mixin.py CHANGED
@@ -9,7 +9,23 @@ from transformers.utils import ModelOutput
9
 
10
 
11
  class TSGenerationMixin(GenerationMixin):
12
- tokenizer = BertTokenizer.from_pretrained(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert_config'), local_files_only=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  @torch.no_grad()
15
  def generate(
@@ -95,7 +111,8 @@ class TSGenerationMixin(GenerationMixin):
95
  }
96
 
97
  def _tokenize(self, texts, max_length):
98
- return self.tokenizer(
 
99
  texts,
100
  padding='max_length',
101
  truncation=True,
 
9
 
10
 
11
  class TSGenerationMixin(GenerationMixin):
12
+ _tokenizer = None
13
+
14
+ def _get_tokenizer(self):
15
+ if self._tokenizer is None:
16
+ base_dir = os.path.dirname(os.path.abspath(__file__))
17
+ tokenizer_dir = os.path.join(base_dir, "bert_config")
18
+
19
+ if not os.path.isdir(tokenizer_dir):
20
+ raise FileNotFoundError(
21
+ f"BERT tokenizer directory not found: {tokenizer_dir}"
22
+ )
23
+
24
+ self._tokenizer = BertTokenizer.from_pretrained(
25
+ tokenizer_dir,
26
+ local_files_only=True
27
+ )
28
+ return self._tokenizer
29
 
30
  @torch.no_grad()
31
  def generate(
 
111
  }
112
 
113
  def _tokenize(self, texts, max_length):
114
+ tokenizer = self._get_tokenizer()
115
+ return tokenizer(
116
  texts,
117
  padding='max_length',
118
  truncation=True,