Spaces:
Runtime error
Runtime error
Update opensora/serve/gradio_web_server.py
Browse files
opensora/serve/gradio_web_server.py
CHANGED
|
@@ -72,8 +72,23 @@ if __name__ == '__main__':
|
|
| 72 |
vae.latent_size = latent_size
|
| 73 |
transformer_model.force_images = args.force_images
|
| 74 |
tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name)
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
# set eval mode
|
| 79 |
transformer_model.eval()
|
|
|
|
| 72 |
vae.latent_size = latent_size
|
| 73 |
transformer_model.force_images = args.force_images
|
| 74 |
tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name)
|
| 75 |
+
|
| 76 |
+
load_8bit, load_4bit = True, False
|
| 77 |
+
kwargs = {"device_map": "auto"}
|
| 78 |
+
if load_8bit:
|
| 79 |
+
kwargs['load_in_8bit'] = True
|
| 80 |
+
elif load_4bit:
|
| 81 |
+
from transformers import BitsAndBytesConfig
|
| 82 |
+
kwargs['load_in_4bit'] = True
|
| 83 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
| 84 |
+
load_in_4bit=True,
|
| 85 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 86 |
+
bnb_4bit_use_double_quant=True,
|
| 87 |
+
bnb_4bit_quant_type='nf4'
|
| 88 |
+
)
|
| 89 |
+
else:
|
| 90 |
+
kwargs['torch_dtype'] = torch.float16
|
| 91 |
+
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir="cache_dir", **kwargs)
|
| 92 |
|
| 93 |
# set eval mode
|
| 94 |
transformer_model.eval()
|