Spaces:
Runtime error
Runtime error
fixed cpu error
Browse files
app.py
CHANGED
|
@@ -80,8 +80,11 @@ def prepare_pipeline(model_name):
|
|
| 80 |
if 'dpo' in OUTPUT_DIR:
|
| 81 |
args.unet_path = "mhdang/dpo-sd1.5-text2image-v1"
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
pipe.verbose = True
|
| 87 |
pipe.v = 're'
|
|
@@ -116,7 +119,7 @@ def prepare_pipeline(model_name):
|
|
| 116 |
ID2NAME = open('data/dogs/class_names.txt').readlines()
|
| 117 |
ID2NAME = [line.strip() for line in ID2NAME]
|
| 118 |
|
| 119 |
-
return pipe, MAPPING, ID2NAME
|
| 120 |
|
| 121 |
|
| 122 |
def download_file(url, local_path):
|
|
@@ -159,11 +162,11 @@ def process_text(text, MAPPING, ID2NAME):
|
|
| 159 |
|
| 160 |
|
| 161 |
def generate_images(model_name, prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, seed):
|
| 162 |
-
generator = torch.Generator(device='cuda')
|
| 163 |
-
generator = generator.manual_seed(int(seed))
|
| 164 |
-
|
| 165 |
try:
|
| 166 |
-
pipe, MAPPING, ID2NAME = prepare_pipeline(model_name)
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
prompt, part2id = process_text(prompt, MAPPING, ID2NAME)
|
| 169 |
negative_prompt, _ = process_text(negative_prompt, MAPPING, ID2NAME)
|
|
@@ -179,7 +182,8 @@ def generate_images(model_name, prompt, negative_prompt, num_inference_steps, gu
|
|
| 179 |
f"The error message: {e}")
|
| 180 |
finally:
|
| 181 |
gc.collect()
|
| 182 |
-
torch.cuda.
|
|
|
|
| 183 |
|
| 184 |
return images, '; '.join(part2id)
|
| 185 |
|
|
|
|
| 80 |
if 'dpo' in OUTPUT_DIR:
|
| 81 |
args.unet_path = "mhdang/dpo-sd1.5-text2image-v1"
|
| 82 |
|
| 83 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
| 84 |
+
weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 85 |
+
|
| 86 |
+
pipe = load_pipeline(args, weight_dtype, device)
|
| 87 |
+
pipe = pipe.to(weight_dtype)
|
| 88 |
|
| 89 |
pipe.verbose = True
|
| 90 |
pipe.v = 're'
|
|
|
|
| 119 |
ID2NAME = open('data/dogs/class_names.txt').readlines()
|
| 120 |
ID2NAME = [line.strip() for line in ID2NAME]
|
| 121 |
|
| 122 |
+
return pipe, MAPPING, ID2NAME, device
|
| 123 |
|
| 124 |
|
| 125 |
def download_file(url, local_path):
|
|
|
|
| 162 |
|
| 163 |
|
| 164 |
def generate_images(model_name, prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, seed):
|
|
|
|
|
|
|
|
|
|
| 165 |
try:
|
| 166 |
+
pipe, MAPPING, ID2NAME, device = prepare_pipeline(model_name)
|
| 167 |
+
|
| 168 |
+
generator = torch.Generator(device=device)
|
| 169 |
+
generator = generator.manual_seed(int(seed))
|
| 170 |
|
| 171 |
prompt, part2id = process_text(prompt, MAPPING, ID2NAME)
|
| 172 |
negative_prompt, _ = process_text(negative_prompt, MAPPING, ID2NAME)
|
|
|
|
| 182 |
f"The error message: {e}")
|
| 183 |
finally:
|
| 184 |
gc.collect()
|
| 185 |
+
if torch.cuda.is_available():
|
| 186 |
+
torch.cuda.empty_cache()
|
| 187 |
|
| 188 |
return images, '; '.join(part2id)
|
| 189 |
|