Your Name committed on
Commit
56a1f31
·
1 Parent(s): 2cc59cd
Files changed (1) hide show
  1. src/pipeline.py +53 -18
src/pipeline.py CHANGED
@@ -30,35 +30,66 @@ REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
30
  Pipeline = None
31
  apply_quanto=1
32
 
33
def reset_cache():
    """Run Python garbage collection, then release cached CUDA memory and
    reset the allocator's max/peak memory statistics."""
    gc.collect()
    for _clear in (
        torch.cuda.empty_cache,
        torch.cuda.reset_max_memory_allocated,
        torch.cuda.reset_peak_memory_stats,
    ):
        _clear()
 
38
 
39
def load_quanto_text_encoder_2(text_repo_path):
    """Rebuild the quantized T5 text encoder from an on-disk config and a
    local quantization map.

    Args:
        text_repo_path: Directory containing the model's ``config.json``.

    Returns:
        The requantized ``transformers.T5EncoderModel``.
    """
    with open("quantization_map.json", "r") as fh:
        qmap = json.load(fh)

    config_path = os.path.join(text_repo_path, "config.json")
    with open(config_path, "r") as fh:
        t5_config = transformers.T5Config(**json.load(fh))

    # Instantiate the architecture on the meta device so no real weight
    # storage is allocated up front.
    with torch.device("meta"):
        encoder = transformers.T5EncoderModel(t5_config).to(torch.bfloat16)

    # NOTE(review): state_dict is None here — presumably `requantize` sources
    # the weights itself; confirm against the requantize implementation.
    requantize(encoder, None, qmap, device=torch.device("cuda"))
    return encoder
49
 
 
50
  def load_pipeline() -> Pipeline:
51
 
52
  try:
53
- text_repo_path = os.path.join(HF_HUB_CACHE, "models--RichardWilliam--XULF_T5_bf16/snapshots/63a3d9ef7b586655600ac9bd4e4747d038237761")
54
- text_encoder_2 = load_quanto_text_encoder_2(text_repo_path=text_repo_path)
55
  except:
56
- text_encoder_2 = T5EncoderModel.from_pretrained("RichardWilliam/XULF_T5_bf16",
57
  revision = "63a3d9ef7b586655600ac9bd4e4747d038237761",
58
  torch_dtype=torch.bfloat16).to(memory_format=torch.channels_last)
59
 
60
  origin_vae = AutoencoderTiny.from_pretrained("RichardWilliam/XULF_Vae",
61
- revision="3ee225c539465c27adadec45c6e8af50a7397b7d",
62
  torch_dtype=torch.bfloat16)
63
 
64
 
@@ -75,20 +106,24 @@ def load_pipeline() -> Pipeline:
75
  text_encoder_2=text_encoder_2,
76
  torch_dtype=torch.bfloat16)
77
  pipeline.to("cuda")
 
 
 
 
78
 
79
  for __ in range(3):
80
  pipeline(prompt="sweet, subordinative, gender, mormyre, arteriolosclerosis, positivism, Antiochianism, palmerite",
81
- width=1024,
82
- height=1024,
83
- guidance_scale=0.0,
84
- num_inference_steps=4,
85
  max_sequence_length=256)
86
  return pipeline
87
 
88
  @torch.no_grad()
89
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
90
 
91
- reset_cache()
92
 
93
  generator = Generator(pipeline.device).manual_seed(request.seed)
94
 
 
30
  Pipeline = None
31
  apply_quanto=1
32
 
33
+ import torch
34
+ import gc
35
+ import os
36
+ import json
37
+ import transformers
38
+
39
def reset_caching():
    """Release cached CUDA memory, reset the allocator's max/peak memory
    statistics, then trigger Python garbage collection."""
    cuda = torch.cuda
    cuda.empty_cache()
    cuda.reset_max_memory_allocated()
    cuda.reset_peak_memory_stats()
    gc.collect()
45
 
46
def quanto_T5(_path):
    """Build the T5 encoder on the meta device and requantize it on CUDA.

    Args:
        _path (str): Path to the text repository containing the config file.

    Returns:
        transformers.T5EncoderModel: The loaded and quantized T5 encoder model.
    """
    # Quantization layout saved in the working directory.
    with open("quantization_map.json", "r") as fh:
        qmap = json.load(fh)

    # Model configuration lives alongside the checkpoint.
    with open(os.path.join(_path, "config.json"), "r") as fh:
        t5_cfg = transformers.T5Config(**json.load(fh))

    # Meta device: build the module skeleton without allocating real storage.
    with torch.device("meta"):
        encoder = transformers.T5EncoderModel(t5_cfg).to(torch.bfloat16)

    # state_dict is deliberately None — assumed to be loaded or handled
    # within `requantize` itself (confirm against its implementation).
    requantize(
        model=encoder,
        state_dict=None,
        quantization_map=qmap,
        device=torch.device("cuda"),
    )

    return encoder
79
 
80
+
81
  def load_pipeline() -> Pipeline:
82
 
83
  try:
84
+ _path = os.path.join(HF_HUB_CACHE, "models--RichardWilliam--XULF_T5_bf16/snapshots/63a3d9ef7b586655600ac9bd4e4747d038237761")
85
+ text_encoder_2 = quanto_T5(_path=_path)
86
  except:
87
+ text_encoder_2 = T5EncoderModel.from_pretrained("RichardWilliam/XULF_T5_bf16",
88
  revision = "63a3d9ef7b586655600ac9bd4e4747d038237761",
89
  torch_dtype=torch.bfloat16).to(memory_format=torch.channels_last)
90
 
91
  origin_vae = AutoencoderTiny.from_pretrained("RichardWilliam/XULF_Vae",
92
+ revision="3ee225c539465c27adadec45c6e8af50a7397b7d",
93
  torch_dtype=torch.bfloat16)
94
 
95
 
 
106
  text_encoder_2=text_encoder_2,
107
  torch_dtype=torch.bfloat16)
108
  pipeline.to("cuda")
109
+ try:
110
+ quantize_(pipeline.origin_vae, int8_weight_only())
111
+ except:
112
+ pass
113
 
114
  for __ in range(3):
115
  pipeline(prompt="sweet, subordinative, gender, mormyre, arteriolosclerosis, positivism, Antiochianism, palmerite",
116
+ width=1024,
117
+ height=1024,
118
+ guidance_scale=0.0,
119
+ num_inference_steps=4,
120
  max_sequence_length=256)
121
  return pipeline
122
 
123
  @torch.no_grad()
124
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
125
 
126
+ reset_caching()
127
 
128
  generator = Generator(pipeline.device).manual_seed(request.seed)
129