Yanderu commited on
Commit
d7c450e
·
1 Parent(s): 3a84aea

Upload civitai-api.py

Browse files
Files changed (1) hide show
  1. scripts/civitai-api.py +489 -0
scripts/civitai-api.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ from fake_useragent import UserAgent as ua
5
+ import json
6
+ import modules.scripts as scripts
7
+ import gradio as gr
8
+ from modules import script_callbacks
9
+ import time
10
+ import threading
11
+ import urllib.request
12
+ import urllib.error
13
+ import os
14
+ from tqdm import tqdm
15
+ import re
16
+ from requests.exceptions import ConnectionError
17
+ import urllib.request
18
+
19
+ PLACEHOLDER = "<no select>"
20
+
21
+ def download_file(url, file_name):
22
+ # Maximum number of retries
23
+ max_retries = 5
24
+
25
+ # Delay between retries (in seconds)
26
+ retry_delay = 10
27
+
28
+ while True:
29
+ # Check if the file has already been partially downloaded
30
+ if os.path.exists(file_name):
31
+ # Get the size of the downloaded file
32
+ downloaded_size = os.path.getsize(file_name)
33
+
34
+ # Set the range of the request to start from the current
35
+ # size of the downloaded file
36
+ headers = {"Range": f"bytes={downloaded_size}-"}
37
+ else:
38
+ downloaded_size = 0
39
+ headers = {}
40
+
41
+ # Split filename from included path
42
+ tokens = re.split(re.escape('\\'), file_name)
43
+ file_name_display = tokens[-1]
44
+
45
+ # Initialize the progress bar
46
+ progress = tqdm(total=1000000000, unit="B", unit_scale=True,
47
+ desc=f"Downloading {file_name_display}",
48
+ initial=downloaded_size, leave=False)
49
+
50
+ # Open a local file to save the download
51
+ with open(file_name, "ab") as f:
52
+ while True:
53
+ try:
54
+ # Send a GET request to the URL and save the response to the local file
55
+ response = requests.get(url, headers=headers, stream=True)
56
+
57
+ # Get the total size of the file
58
+ total_size = int(response.headers.get("Content-Length", 0))
59
+
60
+ # Update the total size of the progress bar if the `Content-Length` header is present
61
+ if total_size == 0:
62
+ total_size = downloaded_size
63
+ progress.total = total_size
64
+
65
+ # Write the response to the local file and update the progress bar
66
+ for chunk in response.iter_content(chunk_size=1024):
67
+ if chunk: # filter out keep-alive new chunks
68
+ f.write(chunk)
69
+ progress.update(len(chunk))
70
+
71
+ downloaded_size = os.path.getsize(file_name)
72
+ # Break out of the loop if the download is successful
73
+ break
74
+ except ConnectionError as e:
75
+ # Decrement the number of retries
76
+ max_retries -= 1
77
+
78
+ # If there are no more retries, raise the exception
79
+ if max_retries == 0:
80
+ raise e
81
+
82
+ # Wait for the specified delay before retrying
83
+ time.sleep(retry_delay)
84
+
85
+ # Close the progress bar
86
+ progress.close()
87
+ downloaded_size = os.path.getsize(file_name)
88
+ # Check if the download was successful
89
+ if downloaded_size >= total_size:
90
+ print(f"{file_name_display} successfully downloaded.")
91
+ break
92
+ else:
93
+ print(f"Error: File download failed. Retrying... {file_name_display}")
94
+
95
+ def make_new_folder(content_type, use_new_folder, model_name, lora_old):
96
+ if content_type == "Checkpoint":
97
+ folder = "models/Stable-diffusion"
98
+ new_folder = "models/Stable-diffusion/new"
99
+ elif content_type == "Hypernetwork":
100
+ folder = "models/hypernetworks"
101
+ new_folder = "models/hypernetworks/new"
102
+ elif content_type == "TextualInversion":
103
+ folder = "embeddings"
104
+ new_folder = "embeddings/new"
105
+ elif content_type == "AestheticGradient":
106
+ folder = "extensions/stable-diffusion-webui-aesthetic-gradients/aesthetic_embeddings"
107
+ new_folder = "extensions/stable-diffusion-webui-aesthetic-gradients/aesthetic_embeddings/new"
108
+ elif content_type == "VAE":
109
+ folder = "models/VAE"
110
+ new_folder = "models/VAE/new"
111
+ elif content_type == "LORA":
112
+ if lora_old:
113
+ folder = "extensions/sd-webui-additional-networks/models/lora"
114
+ new_folder = "extensions/sd-webui-additional-networks/models/lora/new"
115
+ else:
116
+ folder = "models/Lora"
117
+ new_folder = "models/Lora/new"
118
+ elif content_type == "LoCon":
119
+ if lora_old:
120
+ folder = "extensions/sd-webui-additional-networks/models/lora"
121
+ new_folder = "extensions/sd-webui-additional-networks/models/lora/new"
122
+ else:
123
+ folder = "models/Lora"
124
+ new_folder = "models/Lora/new"
125
+ if content_type == "TextualInversion" or content_type == "VAE" or \
126
+ content_type == "AestheticGradient":
127
+ if use_new_folder:
128
+ model_folder = new_folder
129
+ if not os.path.exists(new_folder):
130
+ os.makedirs(new_folder)
131
+
132
+ else:
133
+ model_folder = folder
134
+ if not os.path.exists(model_folder):
135
+ os.makedirs(model_folder)
136
+ else:
137
+ if use_new_folder:
138
+ model_folder = os.path.join(new_folder,model_name.replace(" ","_").replace("(","").replace(")","").replace("|","").replace(":","-"))
139
+ if not os.path.exists(new_folder):
140
+ os.makedirs(new_folder)
141
+ if not os.path.exists(model_folder):
142
+ os.makedirs(model_folder)
143
+
144
+ else:
145
+ model_folder = os.path.join(folder,model_name.replace(" ","_").replace("(","").replace(")","").replace("|","").replace(":","-"))
146
+ if not os.path.exists(model_folder):
147
+ os.makedirs(model_folder)
148
+ return model_folder
149
+
150
+ def download_file_thread(url, file_name, content_type, use_new_folder, model_name, lora_old):
151
+ model_folder = make_new_folder(content_type, use_new_folder, model_name, lora_old)
152
+
153
+ path_to_new_file = os.path.join(model_folder, file_name)
154
+
155
+ thread = threading.Thread(target=download_file, args=(url, path_to_new_file))
156
+
157
+ # Start the thread
158
+ thread.start()
159
+
160
+ def save_text_file(file_name, content_type, use_new_folder, trained_words, model_name, lora_old):
161
+ model_folder = make_new_folder(content_type, use_new_folder, model_name, lora_old)
162
+
163
+ path_to_new_file = os.path.join(model_folder, file_name.replace(".ckpt",".txt").replace(".safetensors",".txt").replace(".pt",".txt").replace(".yaml",".txt"))
164
+ if not os.path.exists(path_to_new_file):
165
+ with open(path_to_new_file, 'w') as f:
166
+ f.write(trained_words)
167
+ if os.path.getsize(path_to_new_file) == 0:
168
+ print("Current model doesn't have any trained tags")
169
+ else:
170
+ print("Trained tags saved as text file")
171
+
172
+ # Set the URL for the API endpoint
173
+ api_url = "https://civitai.com/api/v1/models?limit=50"
174
+ json_data = None
175
+
176
+ def api_to_data(content_type, sort_type, use_search_term, search_term=None):
177
+ if use_search_term and search_term:
178
+ search_term = search_term.replace(" ","%20")
179
+ return request_civit_api(f"{api_url}&types={content_type}&sort={sort_type}&query={search_term}")
180
+ else:
181
+ return request_civit_api(f"{api_url}&types={content_type}&sort={sort_type}")
182
+
183
+ def api_next_page(next_page_url=None):
184
+ global json_data
185
+ try: json_data['metadata']['nextPage']
186
+ except: return
187
+ if json_data['metadata']['nextPage'] is not None:
188
+ next_page_url = json_data['metadata']['nextPage']
189
+ if next_page_url is not None:
190
+ return request_civit_api(next_page_url)
191
+
192
+ def update_next_page(show_nsfw):
193
+ global json_data
194
+ json_data = api_next_page()
195
+ model_dict = {}
196
+ try: json_data['items']
197
+ except TypeError: return gr.Dropdown.update(choices=[], value=None)
198
+ if show_nsfw:
199
+ for item in json_data['items']:
200
+ model_dict[item['name']] = item['name']
201
+ else:
202
+ for item in json_data['items']:
203
+ temp_nsfw = item['nsfw']
204
+ if not temp_nsfw:
205
+ model_dict[item['name']] = item['name']
206
+ return gr.Dropdown.update(choices=[PLACEHOLDER] + [v for k, v in model_dict.items()], value=PLACEHOLDER), gr.Dropdown.update(choices=[], value=None)
207
+
208
+
209
+ def update_model_list(content_type, sort_type, use_search_term, search_term, show_nsfw):
210
+ global json_data
211
+ json_data = api_to_data(content_type, sort_type, use_search_term, search_term)
212
+ model_dict = {}
213
+ if show_nsfw:
214
+ for item in json_data['items']:
215
+ model_dict[item['name']] = item['name']
216
+ else:
217
+ for item in json_data['items']:
218
+ temp_nsfw = item['nsfw']
219
+ if not temp_nsfw:
220
+ model_dict[item['name']] = item['name']
221
+ return gr.Dropdown.update(choices=[PLACEHOLDER] + [v for k, v in model_dict.items()], value=PLACEHOLDER), gr.Dropdown.update(choices=[], value=None)
222
+
223
+ def update_model_versions(model_name=None):
224
+ if model_name is not None and model_name != PLACEHOLDER:
225
+ global json_data
226
+ versions_dict = {}
227
+ for item in json_data['items']:
228
+ if item['name'] == model_name:
229
+
230
+ for model in item['modelVersions']:
231
+ versions_dict[model['name']] = item["name"]
232
+ return gr.Dropdown.update(choices=[PLACEHOLDER] + [k + ' - ' + v for k, v in versions_dict.items()], value=PLACEHOLDER)
233
+ else:
234
+ return gr.Dropdown.update(choices=[], value=None)
235
+
236
+ def update_dl_url(model_name=None, model_version=None, model_filename=None):
237
+ if model_filename:
238
+ global json_data
239
+ dl_dict = {}
240
+ dl_url = None
241
+ model_version = model_version.replace(f' - {model_name}','').strip()
242
+ for item in json_data['items']:
243
+ if item['name'] == model_name:
244
+ for model in item['modelVersions']:
245
+ if model['name'] == model_version:
246
+ for file in model['files']:
247
+ if file['name'] == model_filename:
248
+ dl_url = file['downloadUrl']
249
+ return gr.Textbox.update(value=dl_url)
250
+ else:
251
+ return gr.Textbox.update(value=None)
252
+
253
+ def update_model_info(model_name=None, model_version=None):
254
+ if model_name and model_version and model_name != PLACEHOLDER and model_version != PLACEHOLDER:
255
+ model_version = model_version.replace(f' - {model_name}','').strip()
256
+ global json_data
257
+ output_html = ""
258
+ output_training = ""
259
+ img_html = ""
260
+ model_desc = ""
261
+ dl_dict = {}
262
+ for item in json_data['items']:
263
+ if item['name'] == model_name:
264
+ model_uploader = item['creator']['username']
265
+ if item['description']:
266
+ model_desc = item['description']
267
+ for model in item['modelVersions']:
268
+ if model['name'] == model_version:
269
+ if model['trainedWords']:
270
+ output_training = ", ".join(model['trainedWords'])
271
+
272
+ for file in model['files']:
273
+ dl_dict[file['name']] = file['downloadUrl']
274
+
275
+ model_url = model['downloadUrl']
276
+ #model_filename = model['files']['name']
277
+
278
+ img_html = '<HEAD><style>img { display: inline-block; }</style></HEAD><div class="column">'
279
+ for pic in model['images']:
280
+ img_html = img_html + f'<img src={pic["url"]} width=400px></img>'
281
+ img_html = img_html + '</div>'
282
+ output_html = f"<p><b>Model:</b> {model_name}<br><b>Version:</b> {model_version}<br><b>Uploaded by:</b> {model_uploader}<br><br><a href={model_url}><b>Download Here</b></a></p><br><br>{model_desc}<br><div align=center>{img_html}</div>"
283
+
284
+ return gr.HTML.update(value=output_html), gr.Textbox.update(value=output_training), gr.Dropdown.update(choices=[PLACEHOLDER] + [k for k, v in dl_dict.items()], value=PLACEHOLDER)
285
+ else:
286
+ return gr.HTML.update(value=None), gr.Textbox.update(value=None), gr.Dropdown.update(choices=[], value=None)
287
+
288
+
289
+ def request_civit_api(api_url=None):
290
+ # Make a GET request to the API
291
+ response = requests.get(api_url)
292
+
293
+ # Check the status code of the response
294
+ if response.status_code != 200:
295
+ print("Request failed with status code: {}".format(response.status_code))
296
+ exit()
297
+
298
+ data = json.loads(response.text)
299
+ return data
300
+
301
+ #from https://github.com/thetrebor/sd-civitai-browser/blob/add-download-images/scripts/civitai-api.py
302
+ def update_everything(list_models, list_versions, model_filename, dl_url):
303
+ (a, d, f) = update_model_info(list_models, list_versions)
304
+ dl_url = update_dl_url(list_models, list_versions, f['value'])
305
+ return (a, d, f, list_versions, list_models, dl_url)
306
+
307
+ def save_image_files(preview_image_html, model_filename, content_type, use_new_folder, list_models, lora_old):
308
+ print("Save Images Clicked")
309
+ model_folder = make_new_folder(content_type, use_new_folder, list_models, lora_old)
310
+
311
+ img_urls = re.findall(r'src=[\'"]?([^\'" >]+)', preview_image_html)
312
+
313
+ name = os.path.splitext(model_filename)[0]
314
+ assert(name != "<no select>"), "Please select a Model Filename to download"
315
+ current_directory = os.getcwd()
316
+ while os.path.basename(current_directory) != "stable-diffusion-webui":
317
+ current_directory = os.path.dirname(current_directory)
318
+ new_model_folder = os.path.join(current_directory, model_folder)
319
+ # new_model_folder = os.path.join(current_directory,list_models.replace(" ","_").replace("(","").replace(")","").replace("|","").replace(":","-"))
320
+
321
+ headers = {"User-Agent": str(ua.random)}
322
+ print(img_urls)
323
+
324
+ for i, img_url in enumerate(img_urls):
325
+ filename = f'{name}_{i}.png'
326
+ # img_url = img_url.replace("https", "http").replace("=","%3D")
327
+
328
+ print(f'Downloading {img_url} to {filename}')
329
+ try:
330
+ with requests.get(img_url, headers) as url:
331
+ with open(os.path.join(new_model_folder, filename), 'wb') as f:
332
+ with Image.open(BytesIO(url.content)) as save_me:
333
+ save_me.save(f)
334
+ print(f'Downloaded {img_url}')
335
+ # with urllib.request.urlretrieve(img_url, os.path.join(model_folder, filename)) as dl:
336
+
337
+ except urllib.error.URLError as e:
338
+ print(f'Error: {e.reason}')
339
+
340
+ finally:
341
+ print("Images downloaded.")
342
+
343
+ if os.path.exists(os.path.join(new_model_folder, f'{name}_0.png')):
344
+ with open(os.path.join(new_model_folder, f'{name}_0.png'), 'rb') as f_in:
345
+ with open(os.path.join(new_model_folder, f'{name}.png'), 'wb') as f_out:
346
+ f_out.write(f_in.read())
347
+
348
+ def on_ui_tabs():
349
+ with gr.Blocks() as civitai_interface:
350
+ with gr.Row():
351
+ with gr.Column(scale=2):
352
+ content_type = gr.Radio(label='Content type:', choices=["Checkpoint","Hypernetwork","TextualInversion","AestheticGradient", "VAE", "LORA", "LoCon"], value="Checkpoint", type="value")
353
+ with gr.Column(scale=2):
354
+ sort_type = gr.Radio(label='Sort List by:', choices=["Newest","Most Downloaded","Highest Rated","Most Liked"], value="Newest", type="value")
355
+ with gr.Column(scale=1):
356
+ show_nsfw = gr.Checkbox(label="Show NSFW", value=True)
357
+ with gr.Row():
358
+ use_search_term = gr.Checkbox(label="Search by term?", value=False)
359
+ search_term = gr.Textbox(label="Search Term", interactive=True, lines=1)
360
+ with gr.Row():
361
+ get_list_from_api = gr.Button(label="Get List", value="Get List")
362
+ get_next_page = gr.Button(value="Next Page")
363
+ with gr.Row():
364
+ list_models = gr.Dropdown(label="Model", choices=[], interactive=True, elem_id="quicksettings", value=None)
365
+ list_versions = gr.Dropdown(label="Version", choices=[], interactive=True, elem_id="quicksettings", value=None)
366
+ with gr.Row():
367
+ txt_list = ""
368
+ dummy = gr.Textbox(label='Trained Tags (if any)', value=f'{txt_list}', interactive=True, lines=1)
369
+ model_filename = gr.Dropdown(label="Model Filename", choices=[], interactive=True, value=None)
370
+ dl_url = gr.Textbox(label="Download Url", interactive=False, value=None)
371
+ with gr.Row():
372
+ update_info = gr.Button(value='1st - Get Model Info')
373
+ save_text = gr.Button(value="2nd - Save Trained Tags as Text")
374
+ save_images = gr.Button(value="3rd - Save Images")
375
+ download_model = gr.Button(value="4th - Download Model")
376
+ with gr.Row():
377
+ save_model_in_new = gr.Checkbox(label="Save Model to new folder", value=False)
378
+ old_lora = gr.Checkbox(label="Save LoRA to additional-networks", value=True)
379
+ with gr.Row():
380
+ preview_image_html = gr.HTML()
381
+ save_text.click(
382
+ fn=save_text_file,
383
+ inputs=[
384
+ model_filename,
385
+ content_type,
386
+ save_model_in_new,
387
+ dummy,
388
+ list_models,
389
+ old_lora,
390
+ ],
391
+ outputs=[]
392
+ )
393
+ save_images.click(
394
+ fn=save_image_files,
395
+ inputs=[
396
+ preview_image_html,
397
+ model_filename,
398
+ content_type,
399
+ save_model_in_new,
400
+ list_models,
401
+ old_lora,
402
+ ],
403
+ outputs=[]
404
+ )
405
+ download_model.click(
406
+ fn=download_file_thread,
407
+ inputs=[
408
+ dl_url,
409
+ model_filename,
410
+ content_type,
411
+ save_model_in_new,
412
+ list_models,
413
+ old_lora,
414
+ ],
415
+ outputs=[]
416
+ )
417
+ get_list_from_api.click(
418
+ fn=update_model_list,
419
+ inputs=[
420
+ content_type,
421
+ sort_type,
422
+ use_search_term,
423
+ search_term,
424
+ show_nsfw,
425
+ ],
426
+ outputs=[
427
+ list_models,
428
+ list_versions,
429
+ ]
430
+ )
431
+ update_info.click(
432
+ fn=update_everything,
433
+ #fn=update_model_info,
434
+ inputs=[
435
+ list_models,
436
+ list_versions,
437
+ model_filename,
438
+ dl_url
439
+ ],
440
+ outputs=[
441
+ preview_image_html,
442
+ dummy,
443
+ model_filename,
444
+ list_versions,
445
+ list_models,
446
+ dl_url
447
+ ]
448
+ )
449
+ list_models.change(
450
+ fn=update_model_versions,
451
+ inputs=[
452
+ list_models,
453
+ ],
454
+ outputs=[
455
+ list_versions,
456
+ ]
457
+ )
458
+
459
+ list_versions.change(
460
+ fn=update_model_info,
461
+ inputs=[
462
+ list_models,
463
+ list_versions,
464
+ ],
465
+ outputs=[
466
+ preview_image_html,
467
+ dummy,
468
+ model_filename,
469
+ ]
470
+ )
471
+ model_filename.change(
472
+ fn=update_dl_url,
473
+ inputs=[list_models, list_versions, model_filename,],
474
+ outputs=[dl_url,]
475
+ )
476
+ get_next_page.click(
477
+ fn=update_next_page,
478
+ inputs=[
479
+ show_nsfw,
480
+ ],
481
+ outputs=[
482
+ list_models,
483
+ list_versions,
484
+ ]
485
+ )
486
+
487
+ return (civitai_interface, "CivitAi", "civitai_interface"),
488
+
489
+ script_callbacks.on_ui_tabs(on_ui_tabs)