Leofreddare commited on
Commit
929d5bb
·
verified ·
1 Parent(s): a866a42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -18
app.py CHANGED
@@ -1,36 +1,36 @@
1
  import torch
2
  import gradio as gr
3
- from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
4
- from peft import inject_adapter_in_model
5
  from huggingface_hub import hf_hub_download
6
- import os
7
 
8
- # Download LoRA from HF Hub
9
- lora_path = hf_hub_download(repo_id="Leofreddare/DreamCartoonLora", filename="DreamCartoonLora.safetensors")
 
 
 
10
 
11
- # Load SDXL base pipeline
12
  pipe = StableDiffusionXLPipeline.from_pretrained(
13
  "stabilityai/stable-diffusion-xl-base-1.0",
14
- torch_dtype=torch.float16,
15
- variant="fp16",
16
  use_safetensors=True
17
- ).to("cuda")
18
 
19
- # Load LoRA
20
- pipe.load_lora_weights(lora_path)
21
-
22
- # Set up scheduler
23
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
24
 
 
25
  def generate(prompt):
26
- image = pipe(prompt=prompt).images[0]
 
27
  return image
28
 
29
- # Gradio UI
30
  gr.Interface(
31
  fn=generate,
32
- inputs=gr.Textbox(label="Prompt"),
33
  outputs="image",
34
  title="DreamCartoonLora - SDXL 1.0",
35
- description="Generate cartoon-style images using a fine-tuned LoRA on SDXL 1.0"
36
  ).launch()
 
 
1
  import torch
2
  import gradio as gr
3
+ from diffusers import StableDiffusionXLPipeline
 
4
  from huggingface_hub import hf_hub_download
 
5
 
6
+ # Download your LoRA weights
7
+ lora_weights = hf_hub_download(
8
+ repo_id="Leofreddare/DreamCartoonLora",
9
+ filename="DreamCartoonLora.safetensors"
10
+ )
11
 
12
+ # Load base SDXL pipeline on CPU
13
  pipe = StableDiffusionXLPipeline.from_pretrained(
14
  "stabilityai/stable-diffusion-xl-base-1.0",
15
+ torch_dtype=torch.float32, # Use float32 for CPU
 
16
  use_safetensors=True
17
+ ).to("cpu")
18
 
19
+ # Load the LoRA weights
20
+ pipe.load_lora_weights(lora_weights)
 
 
 
21
 
22
+ # Inference function
23
  def generate(prompt):
24
+ with torch.no_grad():
25
+ image = pipe(prompt=prompt).images[0]
26
  return image
27
 
28
+ # Simple UI
29
  gr.Interface(
30
  fn=generate,
31
+ inputs=gr.Textbox(label="Enter your prompt"),
32
  outputs="image",
33
  title="DreamCartoonLora - SDXL 1.0",
34
+ description="Cartoon LoRA applied to Stable Diffusion XL 1.0"
35
  ).launch()
36
+