Spaces:
Running on Zero
Running on Zero
Update decoder.py
Browse files — decoder.py (+11 lines, −3 lines)
decoder.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
import torch.nn as nn
|
| 2 |
import torch
|
| 3 |
-
from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig
|
| 4 |
-
|
|
|
|
| 5 |
class SketchDecoder(nn.Module):
|
| 6 |
"""
|
| 7 |
Autoregressive generative model
|
|
@@ -23,10 +24,17 @@ class SketchDecoder(nn.Module):
|
|
| 23 |
eos_token_id=self.eos_token_id,
|
| 24 |
pad_token_id=self.pad_token_id)
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 27 |
"Qwen/Qwen2.5-VL-3B-Instruct",
|
| 28 |
config=config,
|
| 29 |
-
|
|
|
|
| 30 |
device_map ="cuda",
|
| 31 |
ignore_mismatched_sizes=True
|
| 32 |
)
|
|
|
|
| 1 |
import torch.nn as nn
|
| 2 |
import torch
|
| 3 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig,BitsAndBytesConfig
|
| 4 |
+
|
| 5 |
+
|
| 6 |
class SketchDecoder(nn.Module):
|
| 7 |
"""
|
| 8 |
Autoregressive generative model
|
|
|
|
| 24 |
eos_token_id=self.eos_token_id,
|
| 25 |
pad_token_id=self.pad_token_id)
|
| 26 |
|
| 27 |
+
quantization_config = BitsAndBytesConfig(
|
| 28 |
+
load_in_4bit=True,
|
| 29 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 30 |
+
bnb_4bit_quant_type="nf4"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 34 |
"Qwen/Qwen2.5-VL-3B-Instruct",
|
| 35 |
config=config,
|
| 36 |
+
quantization_config=quantization_config,
|
| 37 |
+
torch_dtype=torch.bfloat16, attn_implementation="sdpa",
|
| 38 |
device_map ="cuda",
|
| 39 |
ignore_mismatched_sizes=True
|
| 40 |
)
|