litvit5 committed
Commit 244d6df · 1 Parent(s): 1dde7ac
Files changed (5)
  1. README.md +123 -0
  2. config.json +14 -0
  3. config.py +23 -0
  4. modeling_litevit5.py +285 -0
  5. pytorch_model.bin +3 -0
README.md ADDED
@@ -0,0 +1,123 @@
+ ---
+ datasets:
+ - HuggingFaceM4/WebSight
+ base_model:
+ - Salesforce/codet5-base
+ - google/siglip2-base-patch16-512
+ ---
+
+ # LiteVit5 - Image-to-HTML Model
+
+ A lightweight transformer model that combines a frozen SigLIP vision encoder with a CodeT5 seq2seq decoder for image-to-HTML generation.
+
+ ## Model Architecture
+
+ - **Vision Encoder**: SigLIP2 (frozen)
+ - **Vision Processing**: Multi-view fusion (see the shape sketch below)
+ - **Seq2Seq Decoder**: CodeT5-based decoder with language modeling head
+ - **Input**: Images (5 views per sample: 4 quarter views + 1 full view)
+ - **Output**: Generated HTML
+
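+ The fusion step is easiest to follow as a shape walkthrough. A minimal sketch (illustrative only; the actual implementation is `_encode_vision` in `modeling_litevit5.py`) showing how the four quarter-view feature maps are stitched back into a single 64×64 patch grid before being downsampled and fused with the full view:
+
+ ```python
+ import torch
+
+ # Per-view SigLIP features: 5 views x 1024 patch tokens x 768 dims (batch size 1 here)
+ views = torch.randn(1, 5, 1024, 768)
+
+ quarters = views[:, :4].reshape(1, 4, 32, 32, 768)         # each quarter is a 32x32 patch grid
+ upper = torch.cat([quarters[:, 0], quarters[:, 1]], dim=2)  # [1, 32, 64, 768]
+ lower = torch.cat([quarters[:, 2], quarters[:, 3]], dim=2)  # [1, 32, 64, 768]
+ grid = torch.cat([upper, lower], dim=1)                     # [1, 64, 64, 768] stitched grid
+
+ # A strided conv then brings this back to 32x32 = 1024 tokens, which are concatenated
+ # with the 1024 full-view tokens and projected to 768 dims, yielding the [1, 1024, 768]
+ # sequence that conditions the CodeT5 decoder.
+ print(grid.permute(0, 3, 1, 2).shape)  # torch.Size([1, 768, 64, 64])
+ ```
+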
+ ## Installation
+
+ ```bash
+ uv add transformers torch accelerate
+ ```
+
+ ## Usage
+
+ ### Loading the Model
+
+ ```python
+ from transformers import AutoModel, AutoTokenizer
+ from transformers import SiglipProcessor
+
+ # Load the model
+ model = AutoModel.from_pretrained("LiteVit5/model", trust_remote_code=True)
+
+ # Load tokenizer and processor
+ tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
+ processor = SiglipProcessor.from_pretrained("google/siglip2-base-patch16-512")
+ ```
+
+ ### Inference Example
+
+ ```python
+ from PIL import Image
+ import torch
+
+ from transformers import AutoModel, AutoTokenizer
+ from transformers import SiglipProcessor
+
+ # Load the model
+ model = AutoModel.from_pretrained("LiteVit5/model", trust_remote_code=True, device_map="auto")
+
+ # Load tokenizer and processor
+ tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
+ processor = SiglipProcessor.from_pretrained("google/siglip2-base-patch16-512")
+
+ # Preprocess image (split into 4 parts + full image = 5 views)
+ def prepare_image(image_path: str, processor):
+     """
+     Prepare image with 5 views (4 quarters + full).
+
+     Args:
+         image_path: Path to the image file
+         processor: SigLIP processor
+
+     Returns:
+         Tensor of shape [5, 3, 512, 512]
+     """
+     image = Image.open(image_path).convert("RGB")
+
+     # Split into 4 quarters
+     width, height = image.size
+     quarters = [
+         image.crop((0, 0, width//2, height//2)),           # top-left
+         image.crop((width//2, 0, width, height//2)),        # top-right
+         image.crop((0, height//2, width//2, height)),       # bottom-left
+         image.crop((width//2, height//2, width, height)),   # bottom-right
+     ]
+
+     # Process all views
+     processed = [
+         processor(images=q, return_tensors="pt")["pixel_values"]
+         for q in quarters
+     ]
+     # Add full image
+     processed.append(
+         processor(images=image, return_tensors="pt")["pixel_values"]
+     )
+
+     pixel_values = torch.cat(processed, dim=0)
+     return pixel_values
+
+ def generate_text(model, pixel_values, tokenizer, max_length=512):
+     """
+     Generate text from image.
+
+     Args:
+         model: LiteVit5 model
+         pixel_values: Preprocessed image tensor
+         tokenizer: Tokenizer for decoding
+         max_length: Maximum generation length
+
+     Returns:
+         Generated text string
+     """
+     with torch.no_grad():
+         output_ids = model.generate(pixel_values, max_length=max_length)
+
+     text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     return text
+
+ device = next(model.parameters()).device
+
+ # Prepare the image and generate HTML
+ pixel_values = prepare_image("./image_13.png", processor)
+ pixel_values = pixel_values.to(device)
+ print("\nGenerating HTML from image_13.png...")
+ text = generate_text(model, pixel_values, tokenizer, max_length=2024)
+ print(f"Generated: {text}")
+ ```
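+
+ ### Batched Inference
+
+ `model.generate` consumes `pixel_values` of shape `[B*5, 3, 512, 512]`, so several pages can be processed in one call by concatenating each image's 5 views along the batch dimension. A minimal sketch, reusing `prepare_image` from above (the file names are hypothetical):
+
+ ```python
+ paths = ["./image_13.png", "./image_14.png"]  # hypothetical inputs
+ pixel_values = torch.cat([prepare_image(p, processor) for p in paths], dim=0).to(device)
+
+ output_ids = model.generate(pixel_values, max_length=2024)
+ # One generated sequence per input image
+ for ids in output_ids:
+     print(tokenizer.decode(ids, skip_special_tokens=True))
+ ```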
config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "architectures": ["LiteVit5ForConditionalGeneration"],
+   "model_type": "litevit5",
+   "pad_token_id": 0,
+   "eos_token_id": 2,
+   "decoder_start_token_id": 0,
+   "torch_dtype": "float16",
+   "transformers_version": "4.57.3",
+   "auto_map": {
+     "AutoConfig": "config.LiteVit5Config",
+     "AutoModel": "modeling_litevit5.LiteVit5ForConditionalGeneration",
+     "AutoModelForSeq2SeqLM": "modeling_litevit5.LiteVit5ForConditionalGeneration"
+   }
+ }
config.py ADDED
@@ -0,0 +1,23 @@
+ from transformers import PretrainedConfig
+
+
+ class LiteVit5Config(PretrainedConfig):
+     """
+     Configuration class for LiteVit5ForConditionalGeneration.
+     """
+
+     model_type = "litevit5"
+
+     def __init__(
+         self,
+         pad_token_id: int = 0,
+         eos_token_id: int = 1,
+         decoder_start_token_id: int = 0,
+         **kwargs
+     ):
+         super().__init__(
+             pad_token_id=pad_token_id,
+             eos_token_id=eos_token_id,
+             decoder_start_token_id=decoder_start_token_id,
+             **kwargs
+         )
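+
+
+ # Note: the keyword defaults above only apply when LiteVit5Config is constructed directly.
+ # When the published checkpoint is loaded, the values stored in its config.json take
+ # precedence (e.g. eos_token_id is 2 there, while the default above is 1).
+ # Minimal sketch, illustrative only; assumes the hub repo id used in the README:
+ if __name__ == "__main__":
+     default_cfg = LiteVit5Config()
+     loaded_cfg = LiteVit5Config.from_pretrained("LiteVit5/model")
+     print(default_cfg.eos_token_id, loaded_cfg.eos_token_id)  # 1 vs. 2 (from config.json)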
modeling_litevit5.py ADDED
@@ -0,0 +1,285 @@
+ import math
+ from typing import Optional, Tuple
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel, AutoModelForSeq2SeqLM, SiglipVisionModel
+ from transformers.modeling_outputs import Seq2SeqLMOutput
+ from .config import LiteVit5Config
+
+
+ class LiteVit5ForConditionalGeneration(PreTrainedModel):
+     """
+     LiteVit5 model for vision-to-text generation tasks.
+     Combines SigLIP vision encoder with T5 seq2seq decoder for image-to-text tasks.
+     """
+
+     config_class = LiteVit5Config
+     base_model_prefix = "litevit5"
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         # Vision model (frozen)
+         self.vision_model = SiglipVisionModel.from_pretrained(
+             "google/siglip2-base-patch16-512",
+             dtype=torch.float16
+         )
+         self.vision_model.eval()
+         for param in self.vision_model.parameters():
+             param.requires_grad = False
+
+         # Load seq2seq decoder and lm_head from CodeT5
+         seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(
+             "Salesforce/codet5-base",
+             dtype=torch.float16
+         )
+         self.seq2seq_decoder = seq2seq_model.decoder
+         self.seq2seq_lm_head = seq2seq_model.lm_head
+         self._shift_right = seq2seq_model._shift_right
+
+         # Vision processing layers
+         self.downsampler = nn.Conv2d(768, 768, kernel_size=2, stride=2, bias=False, dtype=torch.float16)
+         self.fuse = nn.Linear(768 * 2, 768).half()
+         self.pos_embedding = nn.Parameter(torch.zeros(1, 1024, 768, dtype=torch.float16), requires_grad=True)
+         self.linear_projection = nn.Linear(768, 768).half()
+
+         self.post_init()
+
+     def get_encoder(self):
+         """Return the vision encoder for the model."""
+         return self.vision_model
+
+     def get_decoder(self):
+         """Return the seq2seq decoder."""
+         return self.seq2seq_decoder
+
+     def _encode_vision(self, pixel_values: torch.Tensor) -> torch.Tensor:
+         """
+         Encode image inputs into vision features.
+
+         Args:
+             pixel_values: Input images of shape [B*5, 3, 512, 512] (5 views per sample)
+
+         Returns:
+             Encoded vision features of shape [B, 1024, 768]
+         """
+         # Ensure pixel_values are float16
+         pixel_values = pixel_values.half()
+
+         batch_size = pixel_values.size(0) // 5
+         scale = 5  # Number of views (4 quarter views + 1 full view)
+         num_patches = 32
+
+         # Get vision embeddings
+         with torch.no_grad():
+             vision_model_outputs = self.vision_model(pixel_values=pixel_values)
+             vision_hidden_states = vision_model_outputs.last_hidden_state  # [B*5, 1024, 768]
+
+         # Reshape to separate views
+         vision_hidden_states = vision_hidden_states.view(batch_size, scale, *vision_hidden_states.shape[1:])  # [B, 5, 1024, 768]
+
+         # Process quarter views
+         quarters = vision_hidden_states[:, :4]  # [B, 4, 1024, 768]
+         quarters = quarters.view(batch_size, 4, num_patches, num_patches, -1)  # [B, 4, 32, 32, 768]
+
+         # Combine quarter views into full image
+         upper = torch.cat([quarters[:, 0], quarters[:, 1]], dim=2)  # [B, 32, 64, 768]
+         lower = torch.cat([quarters[:, 2], quarters[:, 3]], dim=2)  # [B, 32, 64, 768]
+         pooled_image = torch.cat([upper, lower], dim=1)  # [B, 64, 64, 768]
+         pooled_image = pooled_image.permute(0, 3, 1, 2)  # [B, 768, 64, 64]
+
+         # Downsample
+         pooled32 = self.downsampler(pooled_image)  # [B, 768, 32, 32]
+         pooled_tok = pooled32.flatten(2).transpose(1, 2)  # [B, 1024, 768]
+
+         # Full image features
+         full_image = vision_hidden_states[:, 4]  # [B, 1024, 768]
+
+         # Fuse quarter and full views
+         concat = torch.cat([pooled_tok, full_image], dim=-1)  # [B, 1024, 1536]
+         fused = self.fuse(concat)  # [B, 1024, 768]
+
+         # Add positional encoding and project
+         fused = fused + self.pos_embedding
+         vision_hidden_states = self.linear_projection(fused)  # [B, 1024, 768]
+
+         return vision_hidden_states
+
+     def forward(
+         self,
+         pixel_values: torch.Tensor,
+         input_ids: Optional[torch.LongTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Tuple] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         **kwargs
+     ) -> Seq2SeqLMOutput:
+         """
+         Forward pass for the model.
+
+         Args:
+             pixel_values: Vision input images
+             input_ids: Decoder input token IDs
+             labels: Target token IDs for training
+             decoder_input_ids: Decoder input IDs (used during generation)
+             past_key_values: Cached key values for efficient generation
+             attention_mask: Attention mask for decoder inputs
+
+         Returns:
+             Seq2SeqLMOutput with loss, logits, and generation-related outputs
+         """
+         # Encode images
+         encoder_hidden_states = self._encode_vision(pixel_values)
+
+         # Prepare decoder input IDs
+         if decoder_input_ids is None and input_ids is None:
+             decoder_input_ids = self._get_decoder_start_token_id()
+             decoder_input_ids = torch.full(
+                 (pixel_values.shape[0] // 5, 1),
+                 decoder_input_ids,
+                 dtype=torch.long,
+                 device=pixel_values.device
+             )
+
+         if decoder_input_ids is None and input_ids is not None:
+             decoder_input_ids = self._shift_right(input_ids)
+
+         # Pass through decoder
+         decoder_outputs = self.seq2seq_decoder(
+             input_ids=decoder_input_ids,
+             encoder_hidden_states=encoder_hidden_states,
+             past_key_values=past_key_values,
+             attention_mask=attention_mask,
+         )
+
+         sequence_output = decoder_outputs[0]
+         lm_logits = self.seq2seq_lm_head(sequence_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
+             labels = labels.to(lm_logits.device)
+             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+
+         return Seq2SeqLMOutput(
+             loss=loss,
+             logits=lm_logits,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.hidden_states,
+             decoder_attentions=decoder_outputs.attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         decoder_input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         use_cache=None,
+         encoder_outputs=None,
+         **kwargs
+     ):
+         """Prepare inputs for generation."""
+         # Cut decoder_input_ids if past is used
+         if past_key_values is not None:
+             decoder_input_ids = decoder_input_ids[:, -1:]
+
+         return {
+             "input_ids": None,  # encoder_outputs is already defined
+             "encoder_outputs": encoder_outputs,
+             "past_key_values": past_key_values,
+             "decoder_input_ids": decoder_input_ids,
+             "attention_mask": attention_mask,
+             "use_cache": use_cache,
+         }
+
+     def _prepare_encoder_decoder_kwargs_for_generation(
+         self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
+     ):
+         """Encode pixel values to get encoder outputs."""
+         # Encode images if not already done
+         if "encoder_outputs" not in model_kwargs:
+             encoder_outputs = self._encode_vision(inputs_tensor)
+             model_kwargs["encoder_outputs"] = (encoder_outputs,)
+
+         return model_kwargs
+
+     def generate(
+         self,
+         pixel_values: torch.Tensor,
+         max_length: int = 1024,
+         num_beams: int = 1,
+         temperature: float = 1.0,
+         do_sample: bool = False,
+         **kwargs
+     ) -> torch.LongTensor:
+         """
+         Generate text from image inputs.
+
+         Args:
+             pixel_values: Input images [B*5, 3, 512, 512]
+             max_length: Maximum generation length
+             num_beams: Number of beams for beam search (1 = greedy) TODO: Not implemented
+             temperature: Sampling temperature
+             do_sample: Whether to use sampling
+
+         Returns:
+             Generated token sequences
+         """
+         # Encode vision inputs
+         encoder_hidden_states = self._encode_vision(pixel_values)
+         batch_size = pixel_values.shape[0] // 5
+
+         # Start with decoder_start_token_id
+         decoder_input_ids = torch.full(
+             (batch_size, 1),
+             self._get_decoder_start_token_id(),
+             dtype=torch.long,
+             device=pixel_values.device
+         )
+
+         generated_tokens = []
+         past_key_values = None
+
+         for step in range(max_length):
+             with torch.no_grad():
+                 # Get decoder outputs
+                 decoder_outputs = self.seq2seq_decoder(
+                     input_ids=decoder_input_ids if past_key_values is None else decoder_input_ids[:, -1:],
+                     encoder_hidden_states=encoder_hidden_states,
+                     past_key_values=past_key_values,
+                     use_cache=True,
+                 )
+
+                 past_key_values = decoder_outputs.past_key_values
+
+                 # Get logits and generate next token
+                 hidden_states = decoder_outputs[0][:, -1:, :]
+                 lm_logits = self.seq2seq_lm_head(hidden_states)
+
+                 # Apply temperature
+                 if temperature != 1.0:
+                     lm_logits = lm_logits / temperature
+
+                 # Get next token
+                 if do_sample:
+                     probs = torch.softmax(lm_logits[:, -1, :], dim=-1)
+                     next_token = torch.multinomial(probs, num_samples=1)
+                 else:
+                     next_token = torch.argmax(lm_logits[:, -1, :], dim=-1, keepdim=True)
+
+                 # Append to generated tokens
+                 generated_tokens.append(next_token)
+                 decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)
+
+                 # Check for EOS
+                 if (next_token == self.config.eos_token_id).all():
+                     break
+
+         return decoder_input_ids
+
+     def _get_decoder_start_token_id(self) -> int:
+         """Get decoder start token ID."""
+         return self.config.decoder_start_token_id or self.config.pad_token_id
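+
+
+ if __name__ == "__main__":
+     # Minimal fine-tuning sketch, illustrative only and not part of the released checkpoint.
+     # forward() shifts `input_ids` right via CodeT5's _shift_right to build the teacher-forcing
+     # inputs and computes a cross-entropy loss against `labels`, so a training step looks
+     # roughly like this. Assumes a CUDA device (the submodules are instantiated in float16);
+     # the target HTML string and the random pixel tensor are hypothetical placeholders.
+     from transformers import AutoTokenizer
+
+     tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
+     config = LiteVit5Config(pad_token_id=0, eos_token_id=2, decoder_start_token_id=0)
+     model = LiteVit5ForConditionalGeneration(config).to("cuda")
+
+     # Only the decoder, LM head, and fusion layers require gradients; the vision tower stays frozen.
+     optimizer = torch.optim.AdamW(
+         [p for p in model.parameters() if p.requires_grad], lr=1e-4
+     )
+
+     pixel_values = torch.randn(5, 3, 512, 512, device="cuda")  # one sample = 5 views
+     labels = tokenizer("<html><body>Hello</body></html>", return_tensors="pt").input_ids.to("cuda")
+
+     outputs = model(pixel_values=pixel_values, labels=labels, input_ids=labels)
+     outputs.loss.backward()
+     optimizer.step()
+     optimizer.zero_grad()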
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aee68e205cbe66657917c8b719a66e0798425d2ba696ab9e23f81b6f8bbb7875
+ size 758546423