Spaces:
Running
on
Zero
Running
on
Zero
update
Browse files- hf_demo.py +5 -3
hf_demo.py
CHANGED
|
@@ -10,7 +10,9 @@ from PIL import Image
|
|
| 10 |
|
| 11 |
|
| 12 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
|
| 15 |
from inference import get_lora_network, inference, get_validation_dataloader
|
| 16 |
lora_map = {
|
|
@@ -33,7 +35,7 @@ lora_map = {
|
|
| 33 |
"Henri Matisse": "henri-matisse_subset1",
|
| 34 |
"Joan Miro": "joan-miro_subset2",
|
| 35 |
}
|
| 36 |
-
|
| 37 |
def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):
|
| 38 |
adapter_path = lora_map[adapter_choice]
|
| 39 |
if adapter_path not in [None, "None"]:
|
|
@@ -48,7 +50,7 @@ def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0,
|
|
| 48 |
start_noise=-1, show=False, style_prompt="sks art", no_load=True,
|
| 49 |
from_scratch=True, device=device)[0][1.0]
|
| 50 |
return pred_images
|
| 51 |
-
|
| 52 |
def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
|
| 53 |
infer_loader = get_validation_dataloader(prompts, image)
|
| 54 |
network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 13 |
+
dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
|
| 14 |
+
pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
|
| 15 |
+
dtype=dtype).to(device)
|
| 16 |
|
| 17 |
from inference import get_lora_network, inference, get_validation_dataloader
|
| 18 |
lora_map = {
|
|
|
|
| 35 |
"Henri Matisse": "henri-matisse_subset1",
|
| 36 |
"Joan Miro": "joan-miro_subset2",
|
| 37 |
}
|
| 38 |
+
@spaces.GPU
|
| 39 |
def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):
|
| 40 |
adapter_path = lora_map[adapter_choice]
|
| 41 |
if adapter_path not in [None, "None"]:
|
|
|
|
| 50 |
start_noise=-1, show=False, style_prompt="sks art", no_load=True,
|
| 51 |
from_scratch=True, device=device)[0][1.0]
|
| 52 |
return pred_images
|
| 53 |
+
@spaces.GPU
|
| 54 |
def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
|
| 55 |
infer_loader = get_validation_dataloader(prompts, image)
|
| 56 |
network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
|