Leofreddare commited on
Commit
a866a42
·
verified ·
1 Parent(s): 13c4080

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -1,32 +1,36 @@
1
  import torch
2
  import gradio as gr
3
  from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
4
- from peft import PeftModel, PeftConfig
 
 
 
 
 
5
 
6
  # Load SDXL base pipeline
7
- base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
8
  pipe = StableDiffusionXLPipeline.from_pretrained(
9
- base_model_id,
10
  torch_dtype=torch.float16,
11
  variant="fp16",
12
  use_safetensors=True
13
  ).to("cuda")
14
 
15
- # Load LoRA weights
16
- lora_path = "./DreamCartoonLora.safetensors"
17
  pipe.load_lora_weights(lora_path)
18
 
19
- # Optional: set scheduler for faster/cleaner results
20
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
21
 
22
  def generate(prompt):
23
  image = pipe(prompt=prompt).images[0]
24
  return image
25
 
 
26
  gr.Interface(
27
  fn=generate,
28
- inputs=gr.Textbox(label="Enter your prompt"),
29
  outputs="image",
30
- title="Dream Cartoon LoRA - SDXL 1.0",
31
- description="Generate images using the DreamCartoonLora fine-tuned on SDXL 1.0"
32
  ).launch()
 
1
  import torch
2
  import gradio as gr
3
  from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
4
+ from peft import inject_adapter_in_model
5
+ from huggingface_hub import hf_hub_download
6
+ import os
7
+
8
+ # Download LoRA from HF Hub
9
+ lora_path = hf_hub_download(repo_id="Leofreddare/DreamCartoonLora", filename="DreamCartoonLora.safetensors")
10
 
11
  # Load SDXL base pipeline
 
12
  pipe = StableDiffusionXLPipeline.from_pretrained(
13
+ "stabilityai/stable-diffusion-xl-base-1.0",
14
  torch_dtype=torch.float16,
15
  variant="fp16",
16
  use_safetensors=True
17
  ).to("cuda")
18
 
19
+ # Load LoRA
 
20
  pipe.load_lora_weights(lora_path)
21
 
22
+ # Set up scheduler
23
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
24
 
25
  def generate(prompt):
26
  image = pipe(prompt=prompt).images[0]
27
  return image
28
 
29
+ # Gradio UI
30
  gr.Interface(
31
  fn=generate,
32
+ inputs=gr.Textbox(label="Prompt"),
33
  outputs="image",
34
+ title="DreamCartoonLora - SDXL 1.0",
35
+ description="Generate cartoon-style images using a fine-tuned LoRA on SDXL 1.0"
36
  ).launch()