AkashKumarave commited on
Commit
a3c77b4
·
verified ·
1 Parent(s): 086cb3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -30
app.py CHANGED
@@ -8,6 +8,7 @@ from huggingface_hub import hf_hub_download
8
  import os
9
  import logging
10
  from safetensors.torch import load_file
 
11
 
12
  # Set up detailed logging
13
  logging.basicConfig(level=logging.INFO)
@@ -34,24 +35,29 @@ except Exception as e:
34
  logger.error(f"Failed to load InsightFace model: {e}")
35
  raise
36
 
37
- # Download function with explicit path return
38
- def download_file(repo_id, filename, local_dir):
39
  file_path = os.path.join(local_dir, filename)
40
  if not os.path.exists(file_path):
41
- logger.info(f"Downloading {filename} from {repo_id} to {local_dir}...")
42
- try:
43
- downloaded_path = hf_hub_download(
44
- repo_id=repo_id,
45
- filename=filename,
46
- local_dir=local_dir,
47
- cache_dir=cache_dir,
48
- local_files_only=False
49
- )
50
- logger.info(f"Downloaded to {downloaded_path}")
51
- return downloaded_path
52
- except Exception as e:
53
- logger.error(f"Download failed: {e}")
54
- raise
 
 
 
 
 
55
  else:
56
  logger.info(f"Using cached file at {file_path}")
57
  return file_path
@@ -62,7 +68,7 @@ ip_adapter_path = "./"
62
  os.makedirs(kolors_unet_path, exist_ok=True)
63
  os.makedirs(ip_adapter_path, exist_ok=True)
64
 
65
- # Download weights
66
  logger.info("Starting weights download...")
67
  kolors_weights = download_file(
68
  "Kwai-Kolors/Kolors",
@@ -75,28 +81,39 @@ ip_adapter_weights = download_file(
75
  ip_adapter_path
76
  )
77
 
78
- # Load the pipeline
79
  logger.info("Loading Stable Diffusion XL base model...")
80
  try:
81
- pipe = StableDiffusionXLPipeline.from_pretrained(
82
- "stabilityai/stable-diffusion-xl-base-1.0",
83
- torch_dtype=dtype,
84
- safety_checker=None,
85
- local_files_only=False,
86
- cache_dir=cache_dir,
87
- variant="fp16",
88
- use_safetensors=True
89
- )
90
- logger.info("SDXL base model loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
91
  except Exception as e:
92
  logger.error(f"Failed to load SDXL base model: {e}")
93
  raise
94
 
95
- # Load Kolors unet weights with strict=False to ignore mismatches
96
  logger.info(f"Loading Kolors unet weights from {kolors_weights}...")
97
  try:
98
  state_dict = load_file(kolors_weights, device=device)
99
- # Load with strict=False to ignore unexpected keys and size mismatches
100
  pipe.unet.load_state_dict(state_dict, strict=False)
101
  logger.info("Kolors unet weights loaded successfully (with ignored mismatches).")
102
  except Exception as e:
 
8
  import os
9
  import logging
10
  from safetensors.torch import load_file
11
+ import time
12
 
13
  # Set up detailed logging
14
  logging.basicConfig(level=logging.INFO)
 
35
  logger.error(f"Failed to load InsightFace model: {e}")
36
  raise
37
 
38
+ # Download function with retry logic
39
+ def download_file(repo_id, filename, local_dir, max_retries=3):
40
  file_path = os.path.join(local_dir, filename)
41
  if not os.path.exists(file_path):
42
+ for attempt in range(max_retries):
43
+ logger.info(f"Attempt {attempt + 1}/{max_retries}: Downloading {filename} from {repo_id} to {local_dir}...")
44
+ try:
45
+ downloaded_path = hf_hub_download(
46
+ repo_id=repo_id,
47
+ filename=filename,
48
+ local_dir=local_dir,
49
+ cache_dir=cache_dir,
50
+ local_files_only=False
51
+ )
52
+ logger.info(f"Downloaded to {downloaded_path}")
53
+ return downloaded_path
54
+ except Exception as e:
55
+ logger.error(f"Download attempt {attempt + 1} failed: {e}")
56
+ if attempt < max_retries - 1:
57
+ logger.info("Retrying in 5 seconds...")
58
+ time.sleep(5)
59
+ else:
60
+ raise RuntimeError(f"Failed to download {filename} after {max_retries} attempts: {e}")
61
  else:
62
  logger.info(f"Using cached file at {file_path}")
63
  return file_path
 
68
  os.makedirs(kolors_unet_path, exist_ok=True)
69
  os.makedirs(ip_adapter_path, exist_ok=True)
70
 
71
+ # Download weights with retries
72
  logger.info("Starting weights download...")
73
  kolors_weights = download_file(
74
  "Kwai-Kolors/Kolors",
 
81
  ip_adapter_path
82
  )
83
 
84
+ # Load the pipeline with verbose logging and retry logic
85
  logger.info("Loading Stable Diffusion XL base model...")
86
  try:
87
+ max_retries = 3
88
+ for attempt in range(max_retries):
89
+ try:
90
+ logger.info(f"Attempt {attempt + 1}/{max_retries}: Loading SDXL model...")
91
+ pipe = StableDiffusionXLPipeline.from_pretrained(
92
+ "stabilityai/stable-diffusion-xl-base-1.0",
93
+ torch_dtype=dtype,
94
+ safety_checker=None,
95
+ local_files_only=False,
96
+ cache_dir=cache_dir,
97
+ variant="fp16",
98
+ use_safetensors=True
99
+ )
100
+ logger.info("SDXL base model loaded successfully.")
101
+ break
102
+ except Exception as e:
103
+ logger.error(f"Load attempt {attempt + 1} failed: {e}")
104
+ if attempt < max_retries - 1:
105
+ logger.info("Retrying in 5 seconds...")
106
+ time.sleep(5)
107
+ else:
108
+ raise RuntimeError(f"Failed to load SDXL model after {max_retries} attempts: {e}")
109
  except Exception as e:
110
  logger.error(f"Failed to load SDXL base model: {e}")
111
  raise
112
 
113
+ # Load Kolors unet weights with strict=False
114
  logger.info(f"Loading Kolors unet weights from {kolors_weights}...")
115
  try:
116
  state_dict = load_file(kolors_weights, device=device)
 
117
  pipe.unet.load_state_dict(state_dict, strict=False)
118
  logger.info("Kolors unet weights loaded successfully (with ignored mismatches).")
119
  except Exception as e: