shiue2000 commited on
Commit
0045776
·
verified ·
1 Parent(s): 85c8812

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -3
app.py CHANGED
@@ -8,15 +8,37 @@ import uuid
8
  from diffusers import StableDiffusionPipeline
9
  import torch
10
 
 
11
  OUTPUT_DIR = "outputs"
12
  os.makedirs(OUTPUT_DIR, exist_ok=True)
13
 
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
- pipe = StableDiffusionPipeline.from_pretrained(
16
- "Lykon/anything-cartoon", torch_dtype=torch.float16 if device=="cuda" else torch.float32
17
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  pipe.to(device)
19
 
 
20
  def load_image(file):
21
  ext = os.path.splitext(file.name)[1].lower()
22
  if ext in [".heic", ".heif"]:
@@ -100,6 +122,7 @@ def apply_filters(file, mode):
100
 
101
  return processed, output_path
102
 
 
103
  demo = gr.Interface(
104
  fn=apply_filters,
105
  inputs=[
 
8
  from diffusers import StableDiffusionPipeline
9
  import torch
10
 
11
+ # --- OUTPUT DIRECTORY ---
12
  OUTPUT_DIR = "outputs"
13
  os.makedirs(OUTPUT_DIR, exist_ok=True)
14
 
15
+ # --- DEVICE SETUP ---
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
+
18
+ # --- LOAD STABLE DIFFUSION PIPELINE ---
19
+ MODEL_NAME = "Lykon/anything-cartoon" # original model
20
+ PUBLIC_MODEL = "runwayml/stable-diffusion-v1-5" # fallback
21
+
22
+ # Use HF token if required
23
+ HF_TOKEN = os.getenv("HF_TOKEN", None) # set environment variable HF_TOKEN if private repo
24
+
25
+ try:
26
+ pipe = StableDiffusionPipeline.from_pretrained(
27
+ MODEL_NAME,
28
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
29
+ use_auth_token=HF_TOKEN
30
+ )
31
+ except Exception as e:
32
+ print(f"Failed to load '{MODEL_NAME}': {e}")
33
+ print(f"Falling back to public model '{PUBLIC_MODEL}'")
34
+ pipe = StableDiffusionPipeline.from_pretrained(
35
+ PUBLIC_MODEL,
36
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
37
+ )
38
+
39
  pipe.to(device)
40
 
41
+ # --- HELPER FUNCTIONS ---
42
  def load_image(file):
43
  ext = os.path.splitext(file.name)[1].lower()
44
  if ext in [".heic", ".heif"]:
 
122
 
123
  return processed, output_path
124
 
125
+ # --- GRADIO INTERFACE ---
126
  demo = gr.Interface(
127
  fn=apply_filters,
128
  inputs=[