OmniSVG committed on
Commit
3d22eb7
·
verified ·
1 Parent(s): a6f884b

Update decoder.py

Browse files
Files changed (1) hide show
  1. decoder.py +11 -3
decoder.py CHANGED
@@ -1,7 +1,8 @@
1
  import torch.nn as nn
2
  import torch
3
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig
4
-
 
5
  class SketchDecoder(nn.Module):
6
  """
7
  Autoregressive generative model
@@ -23,10 +24,17 @@ class SketchDecoder(nn.Module):
23
  eos_token_id=self.eos_token_id,
24
  pad_token_id=self.pad_token_id)
25
 
 
 
 
 
 
 
26
  self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
27
  "Qwen/Qwen2.5-VL-3B-Instruct",
28
  config=config,
29
- torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
 
30
  device_map ="cuda",
31
  ignore_mismatched_sizes=True
32
  )
 
1
  import torch.nn as nn
2
  import torch
3
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig,BitsAndBytesConfig
4
+
5
+
6
  class SketchDecoder(nn.Module):
7
  """
8
  Autoregressive generative model
 
24
  eos_token_id=self.eos_token_id,
25
  pad_token_id=self.pad_token_id)
26
 
27
+ quantization_config = BitsAndBytesConfig(
28
+ load_in_4bit=True,
29
+ bnb_4bit_compute_dtype=torch.float16,
30
+ bnb_4bit_quant_type="nf4"
31
+ )
32
+
33
  self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
34
  "Qwen/Qwen2.5-VL-3B-Instruct",
35
  config=config,
36
+ quantization_config=quantization_config,
37
+ torch_dtype=torch.bfloat16, attn_implementation="sdpa",
38
  device_map ="cuda",
39
  ignore_mismatched_sizes=True
40
  )