ford442 commited on
Commit
bd67d99
·
verified ·
1 Parent(s): 1842642

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -147,7 +147,16 @@ def infer(
147
  #upscaler_2.to(torch.device('cpu'))
148
  torch.set_float32_matmul_precision("highest")
149
  seed = random.randint(0, MAX_SEED)
150
- generator = torch.Generator(device='cpu').manual_seed(seed)
 
 
 
 
 
 
 
 
 
151
  if expanded:
152
  system_prompt_rewrite = (
153
  "You are an AI assistant that rewrites image prompts to be more descriptive and detailed."
@@ -171,15 +180,7 @@ def infer(
171
  attention_mask_2 = encoded_inputs_2["attention_mask"].to(device)
172
  print("-- tokenize prompt --")
173
  # Google T5
174
- if expanded_only:
175
- pipe.to('cpu')
176
- torch.cuda.empty_cache()
177
- torch.cuda.reset_peak_memory_stats()
178
- else:
179
- torch.cuda.empty_cache()
180
- torch.cuda.reset_peak_memory_stats()
181
- pipe.to(device=device, dtype=torch.bfloat16)
182
- gc.collect()
183
  #input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
184
  outputs = model.generate(
185
  input_ids=input_ids,
 
147
  #upscaler_2.to(torch.device('cpu'))
148
  torch.set_float32_matmul_precision("highest")
149
  seed = random.randint(0, MAX_SEED)
150
+ generator = torch.Generator(device='cuda').manual_seed(seed)
151
+ if expanded_only:
152
+ pipe.to('cpu')
153
+ torch.cuda.empty_cache()
154
+ torch.cuda.reset_peak_memory_stats()
155
+ else:
156
+ torch.cuda.empty_cache()
157
+ torch.cuda.reset_peak_memory_stats()
158
+ pipe.to(device=device, dtype=torch.bfloat16)
159
+ gc.collect()
160
  if expanded:
161
  system_prompt_rewrite = (
162
  "You are an AI assistant that rewrites image prompts to be more descriptive and detailed."
 
180
  attention_mask_2 = encoded_inputs_2["attention_mask"].to(device)
181
  print("-- tokenize prompt --")
182
  # Google T5
183
+
 
 
 
 
 
 
 
 
184
  #input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
185
  outputs = model.generate(
186
  input_ids=input_ids,