tgohblio commited on
Commit
71cbddf
·
1 Parent(s): 9cfcea6

Add support for downloading loras from civitai

Browse files
Files changed (3) hide show
  1. app.py +53 -25
  2. civitai_utils.py +66 -0
  3. requirements.txt +3 -1
app.py CHANGED
@@ -8,6 +8,7 @@ import logging
8
  import numpy as np
9
  import spaces
10
  from typing import Any, Dict, List, Optional, Union
 
11
 
12
  import torch
13
  from PIL import Image
@@ -167,6 +168,42 @@ def update_selection(evt: gr.SelectData, width, height):
167
  height,
168
  )
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  @spaces.GPU
171
  def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
172
  # Clean up previous LoRAs in both cases
@@ -176,7 +213,6 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
176
  # Check if a LoRA is selected
177
  if selected_index is not None and selected_index < len(loras):
178
  selected_lora = loras[selected_index]
179
- lora_path = selected_lora["repo"]
180
  trigger_word = selected_lora["trigger_word"]
181
 
182
  # Prepare Prompt with Trigger Word
@@ -188,24 +224,12 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
188
  prompt_mash = f"{prompt} {trigger_word}"
189
  else:
190
  prompt_mash = f"{trigger_word} {prompt}"
191
- else:
192
- prompt_mash = prompt
193
 
194
- # Load LoRA
195
- with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
196
- weight_name = selected_lora.get("weights", None)
197
- try:
198
- pipe.load_lora_weights(
199
- lora_path,
200
- weight_name=weight_name,
201
- adapter_name="default",
202
- low_cpu_mem_usage=True
203
- )
204
- # Set adapter scale
205
- pipe.set_adapters(["default"], adapter_weights=[lora_scale])
206
- except Exception as e:
207
- print(f"Error loading LoRA: {e}")
208
- gr.Warning("Failed to load LoRA weights. Generating with base model.")
209
  else:
210
  # Base Model Case
211
  print("No LoRA selected. Running with Base Model.")
@@ -236,12 +260,12 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
236
  yield final_image, seed, gr.update(visible=False)
237
 
238
  def get_huggingface_safetensors(link):
239
- split_link = link.split("/")
240
- if(len(split_link) == 2):
241
  model_card = ModelCard.load(link)
242
  base_model = model_card.data.get("base_model")
243
  print(base_model)
244
-
245
  # Relaxed check to allow Z-Image or Flux or others, assuming user knows what they are doing
246
  # or specifically check for Z-Image-Turbo
247
  if base_model not in ["Tongyi-MAI/Z-Image-Turbo", "black-forest-labs/FLUX.1-dev"]:
@@ -264,21 +288,23 @@ def get_huggingface_safetensors(link):
264
  print(e)
265
  gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
266
  raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
267
- return split_link[1], link, safetensors_name, trigger_word, image_url
268
 
269
  def check_custom_model(link):
270
  if(link.startswith("https://")):
271
  if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
272
  link_split = link.split("huggingface.co/")
273
  return get_huggingface_safetensors(link_split[1])
 
 
274
  else:
275
- return get_huggingface_safetensors(link)
276
 
277
  def add_custom_lora(custom_lora):
278
  global loras
279
  if(custom_lora):
280
  try:
281
- title, repo, path, trigger_word, image = check_custom_model(custom_lora)
282
  print(f"Loaded custom LoRA: {repo}")
283
  card = f'''
284
  <div class="custom_lora_card">
@@ -299,7 +325,9 @@ def add_custom_lora(custom_lora):
299
  "title": title,
300
  "repo": repo,
301
  "weights": path,
302
- "trigger_word": trigger_word
 
 
303
  }
304
  print(new_item)
305
  existing_item_index = len(loras)
 
8
  import numpy as np
9
  import spaces
10
  from typing import Any, Dict, List, Optional, Union
11
+ from civitai_utils import get_civitai_safetensors, LORA_CHECKPOINTS_CACHE
12
 
13
  import torch
14
  from PIL import Image
 
168
  height,
169
  )
170
 
171
+
172
+ def load_lora_from_hub(lora: dict, lora_scale: float):
173
+ """Load LoRA weights from huggingface hub"""
174
+ with calculateDuration(f"Loading LoRA weights for {lora.get('title')}"):
175
+ try:
176
+ pipe.load_lora_weights(
177
+ lora.get("repo", ""),
178
+ weight_name=lora.get("weights", None),
179
+ adapter_name="default",
180
+ low_cpu_mem_usage=True
181
+ )
182
+ # Set adapter scale
183
+ pipe.set_adapters(["default"], adapter_weights=[lora_scale])
184
+ except Exception as e:
185
+ print(f"Error loading LoRA: {e}")
186
+ gr.Warning("Failed to load LoRA weights. Generating with base model.")
187
+
188
+
189
+ def load_local_lora(lora: dict, lora_scale: float):
190
+ """Load LoRA weights from local cache folder"""
191
+ with calculateDuration(f"Loading LoRA weights for {lora.get('title')}"):
192
+ try:
193
+ pipe.load_lora_weights(
194
+ LORA_CHECKPOINTS_CACHE,
195
+ cache_dir=LORA_CHECKPOINTS_CACHE,
196
+ weight_name=lora.get("weights", None),
197
+ local_files_only=True,
198
+ low_cpu_mem_usage=True
199
+ )
200
+ # Set adapter scale
201
+ pipe.set_adapters(["default"], adapter_weights=[lora_scale])
202
+ except Exception as e:
203
+ print(f"Error loading LoRA: {e}")
204
+ gr.Warning("Failed to load LoRA weights. Generating with base model.")
205
+
206
+
207
  @spaces.GPU
208
  def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
209
  # Clean up previous LoRAs in both cases
 
213
  # Check if a LoRA is selected
214
  if selected_index is not None and selected_index < len(loras):
215
  selected_lora = loras[selected_index]
 
216
  trigger_word = selected_lora["trigger_word"]
217
 
218
  # Prepare Prompt with Trigger Word
 
224
  prompt_mash = f"{prompt} {trigger_word}"
225
  else:
226
  prompt_mash = f"{trigger_word} {prompt}"
 
 
227
 
228
+ # Special handling of lora loading if there's a civitai key
229
+ if selected_lora.get("src") == "civitai":
230
+ load_local_lora(selected_lora, lora_scale)
231
+ else:
232
+ load_lora_from_hub(selected_lora, lora_scale)
 
 
 
 
 
 
 
 
 
 
233
  else:
234
  # Base Model Case
235
  print("No LoRA selected. Running with Base Model.")
 
260
  yield final_image, seed, gr.update(visible=False)
261
 
262
  def get_huggingface_safetensors(link):
263
+ split_link = link.split("/")
264
+ if(len(split_link) == 2):
265
  model_card = ModelCard.load(link)
266
  base_model = model_card.data.get("base_model")
267
  print(base_model)
268
+
269
  # Relaxed check to allow Z-Image or Flux or others, assuming user knows what they are doing
270
  # or specifically check for Z-Image-Turbo
271
  if base_model not in ["Tongyi-MAI/Z-Image-Turbo", "black-forest-labs/FLUX.1-dev"]:
 
288
  print(e)
289
  gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
290
  raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
291
+ return split_link[1], link, safetensors_name, trigger_word, image_url, "", False
292
 
293
  def check_custom_model(link):
294
  if(link.startswith("https://")):
295
  if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
296
  link_split = link.split("huggingface.co/")
297
  return get_huggingface_safetensors(link_split[1])
298
+ elif "civitai" in link:
299
+ return get_civitai_safetensors(link)
300
  else:
301
+ return ""
302
 
303
  def add_custom_lora(custom_lora):
304
  global loras
305
  if(custom_lora):
306
  try:
307
+ title, repo, path, trigger_word, image, src, nsfw = check_custom_model(custom_lora)
308
  print(f"Loaded custom LoRA: {repo}")
309
  card = f'''
310
  <div class="custom_lora_card">
 
325
  "title": title,
326
  "repo": repo,
327
  "weights": path,
328
+ "trigger_word": trigger_word,
329
+ "src": src,
330
+ "nsfw": nsfw
331
  }
332
  print(new_item)
333
  existing_item_index = len(loras)
civitai_utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, re
2
+ import gradio as gr
3
+ from dotenv import load_dotenv
4
+
5
+ load_dotenv() # load CIVITAI_API_TOKEN as environment variable
6
+ from civitai_downloader import civitai_download, APIClient
7
+
8
+ LORA_CHECKPOINTS_CACHE = os.path.join(os.getcwd(), "loras")
9
+
10
+ def verify_url(url: str) -> Tuple[int, int]:
11
+ """
12
+ Verify if link adheres to the format below:
13
+ https://civitai.com/models/<number>?modelVersionId=<number>
14
+
15
+ returns (models id, model version id)
16
+ """
17
+ models_id = 0
18
+ model_version_id = 0
19
+
20
+ if "civitai" in url and "models" in url and "download" not in url:
21
+ # Extract the first sequence of numbers
22
+ match = re.search(r'\d+', url)
23
+ if match:
24
+ models_id = match.group(0)
25
+
26
+ # Extract the second sequence of numbers
27
+ match2 = re.search(r'(?:\d+.*?)(\d+)', url)
28
+ if match2:
29
+ model_version_id = match2.group(1)
30
+ return int(models_id), int(model_version_id)
31
+
32
+
33
+ def get_model_info(models_id: int, model_version_id: int) -> dict:
34
+ api = APIClient()
35
+ item = {}
36
+ item["title"] = api.get_model(models_id).name
37
+ item["src"] = "civitai"
38
+ get_model_version_details(item, model_version_id)
39
+ return item
40
+
41
+
42
+ def get_model_version_details(item:dict, model_version_id: int) -> None:
43
+ api = APIClient()
44
+ model_version = api.get_model_version(model_version_id)
45
+ item["image"] = model_version.images[0].url
46
+ item["repo"] = LORA_CHECKPOINTS_CACHE # local path to safetensor
47
+ item["weights"] = model_version.files[0].name
48
+ item["trigger_word"] = model_version.trainedWords[0]
49
+ item["nsfw"] = model_version.model.get("nsfw", False)
50
+
51
+
52
+ def get_civitai_safetensors(url: str):
53
+ """Helper function to download lora weights"""
54
+ models_id, model_version_id = verify_url(url)
55
+ if models_id != 0 and model_version_id != 0:
56
+ info = get_model_info(models_id, model_version_id)
57
+ print(info)
58
+
59
+ # make folder and download
60
+ os.makedirs(LORA_CHECKPOINTS_CACHE, exist_ok=True)
61
+ civitai_download(model_version_id, LORA_CHECKPOINTS_CACHE)
62
+ return info.get("title"), info.get("repo"), info.get("path"), info.get("trigger_word"), info.get("image"), info.get("src"), info.get("nsfw")
63
+ else:
64
+ print("Invalid URL - must be in this format https://civitai.com/models/<number>?modelVersionId=<number>")
65
+ gr.Warning(f"Invalid URL - must be in this format https://civitai.com/models/<number>?modelVersionId=<number>")
66
+ return "", "", "", "", "", "", False
requirements.txt CHANGED
@@ -10,4 +10,6 @@ kernels
10
  spaces
11
  torch
12
  numpy
13
- peft
 
 
 
10
  spaces
11
  torch
12
  numpy
13
+ peft
14
+ python-dotenv
15
+ civitai-model-downloader==0.4.7