OmniSVG committed on
Commit
ac3151e
·
verified ·
1 Parent(s): b8ec255

Update decoder.py

Browse files
Files changed (1) hide show
  1. decoder.py +38 -40
decoder.py CHANGED
@@ -2,45 +2,43 @@ import torch.nn as nn
2
  import torch
3
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig, BitsAndBytesConfig
4
 
5
-
6
class SketchDecoder(nn.Module):
    """Autoregressive generative model over sketch tokens.

    Wraps a 4-bit-quantized Qwen2.5-VL-3B backbone whose token embedding
    table is resized to an extended vocabulary covering the sketch tokens.
    """

    def __init__(self, **kwargs):
        """Build the decoder by loading and adapting the pretrained backbone.

        Args:
            **kwargs: accepted for interface compatibility; unused here.
        """
        super().__init__()
        # Extended vocabulary: base Qwen vocab plus extra sketch token ids.
        self.vocab_size = 196042
        self.bos_token_id = 151643   # reuses Qwen's base special-token id
        self.eos_token_id = 196041   # last id of the extended vocabulary
        self.pad_token_id = 151643   # pad shares the bos id

        # Override vocab/special-token ids on the pretrained config so the
        # loaded model agrees with the extended token space.
        config = AutoConfig.from_pretrained(
            "Qwen/Qwen2.5-VL-3B-Instruct",
            vocab_size=self.vocab_size,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id)

        # 4-bit NF4 quantization to fit the 3B backbone in limited memory.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4"
        )

        # ignore_mismatched_sizes is required because the checkpoint's
        # embedding/lm-head shapes differ from the enlarged vocab_size.
        # NOTE(review): bnb compute dtype is float16 while torch_dtype is
        # bfloat16 — mixed on purpose? Confirm before changing.
        self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2.5-VL-3B-Instruct",
            config=config,
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            attn_implementation="sdpa",
            device_map="auto",
            ignore_mismatched_sizes=True
        )

        # Grow the embedding table to cover the added sketch token ids.
        self.transformer.resize_token_embeddings(self.vocab_size)

    def forward(self, *args, **kwargs):
        """Not released; the open-source drop ships weights/loading only."""
        raise NotImplementedError("Forward pass not included in open-source version")
 
2
  import torch
3
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig, BitsAndBytesConfig
4
 
 
5
class SketchDecoder(nn.Module):
    """Autoregressive generative model over sketch tokens.

    Wraps a 4-bit-quantized Qwen2.5-VL-3B backbone whose token embedding
    table is resized to an extended vocabulary covering the sketch tokens.
    """

    def __init__(self, state_dict=None, **kwargs):
        """Build the decoder by loading and adapting the pretrained backbone.

        Args:
            state_dict: optional pre-loaded weight dict forwarded to
                ``from_pretrained`` so weights can be injected directly
                instead of being re-fetched from the hub checkpoint.
            **kwargs: accepted for interface compatibility; unused here.
        """
        super().__init__()
        # Extended vocabulary: base Qwen vocab plus extra sketch token ids.
        self.vocab_size = 196042
        self.bos_token_id = 151643   # reuses Qwen's base special-token id
        self.eos_token_id = 196041   # last id of the extended vocabulary
        self.pad_token_id = 151643   # pad shares the bos id

        # Override vocab/special-token ids on the pretrained config so the
        # loaded model agrees with the extended token space.
        config = AutoConfig.from_pretrained(
            "Qwen/Qwen2.5-VL-3B-Instruct",
            vocab_size=self.vocab_size,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id
        )

        # 4-bit NF4 quantization to fit the 3B backbone in limited memory.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4"
        )

        # ignore_mismatched_sizes is required because the checkpoint's
        # embedding/lm-head shapes differ from the enlarged vocab_size.
        # NOTE(review): bnb compute dtype is float16 while torch_dtype is
        # bfloat16 — mixed on purpose? Confirm before changing.
        self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2.5-VL-3B-Instruct",
            config=config,
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            attn_implementation="sdpa",
            device_map="auto",
            ignore_mismatched_sizes=True,
            state_dict=state_dict  # key change: pass the weight dict in directly
        )

        # Grow the embedding table to cover the added sketch token ids.
        self.transformer.resize_token_embeddings(self.vocab_size)

    def forward(self, *args, **kwargs):
        """Not released; the open-source drop ships weights/loading only."""
        raise NotImplementedError("Forward pass not included in open-source version")