# Provenance: uploaded by dikdimon via the SD-Hub extension (revision 3dabe4a, verified)
from contextlib import closing
from typing import Dict, List
from scripts.iib.db.datamodel import Image as DbImg, Tag, ImageTag, DataBase, Folder
import os
from scripts.iib.tool import (
is_valid_media_path,
get_modified_date,
get_video_type,
is_dev,
get_modified_date,
is_image_file,
case_insensitive_get
)
from scripts.iib.parsers.model import ImageGenerationInfo, ImageGenerationParams
from scripts.iib.logger import logger
from scripts.iib.parsers.index import parse_image_info
from scripts.iib.plugin import plugin_inst_map
# Helper to extract EXIF / generation metadata from a media file.
def get_exif_data(file_path):
    """Return the parsed generation info for *file_path*.

    Videos carry no parseable generation metadata, and any parser failure
    is swallowed (logged only in dev mode), so callers always receive an
    ImageGenerationInfo — possibly an empty one.
    """
    if not get_video_type(file_path):
        try:
            return parse_image_info(file_path)
        except Exception as e:
            if is_dev:
                logger.error("get_exif_data %s", e)
    # video, or parsing failed: fall back to an empty info object
    return ImageGenerationInfo()
def update_image_data(search_dirs: List[str], is_rebuild = False):
    """Walk *search_dirs* recursively and (re)index every media file found.

    Folders whose modified date is unchanged are skipped entirely. Tag
    usage counters are accumulated in memory and flushed to the `tag`
    table once at the end.
    """
    conn = DataBase.get_conn()
    # tag_id -> number of newly created image/tag links in this run
    tag_hits: Dict[int, int] = {}
    if is_rebuild:
        Folder.remove_all(conn)

    def record_img_tag(img_tag: ImageTag):
        # Count first, then persist; duplicate links are ignored by the DB,
        # which is why no try/except is needed here.
        tag_hits[img_tag.tag_id] = tag_hits.get(img_tag.tag_id, 0) + 1
        img_tag.save_or_ignore(conn)

    def walk(folder_path: str):
        """Recursively index one folder, committing per folder."""
        if not Folder.check_need_update(conn, folder_path):
            return
        print(f"Processing folder: {folder_path}")
        for entry in os.listdir(folder_path):
            entry_path = os.path.normpath(os.path.join(folder_path, entry))
            try:
                if os.path.isdir(entry_path):
                    walk(entry_path)
                elif is_valid_media_path(entry_path):
                    build_single_img_idx(conn, entry_path, is_rebuild, record_img_tag)
                # negative prompt is intentionally not indexed — unlikely search target
            except Exception as e:
                logger.error("Tag generation failed. Skipping this file. file:%s error: %s", entry_path, e)
        # Record the folder as processed and persist this folder's work.
        Folder.update_modified_date_or_create(conn, folder_path)
        conn.commit()

    for search_dir in search_dirs:
        walk(search_dir)
    conn.commit()
    # Flush the accumulated per-tag usage increments.
    for tag_id, incr in tag_hits.items():
        tag = Tag.get(conn, tag_id)
        tag.count += incr
        tag.save(conn)
    conn.commit()
def add_image_data_single(file_path):
    """Index a single newly-added media file and update tag usage counts.

    Invalid media paths exit early without touching the database.
    """
    conn = DataBase.get_conn()
    # tag_id -> number of newly created image/tag links for this file
    tag_hits: Dict[int, int] = {}

    def record_img_tag(img_tag: ImageTag):
        # Count first, then persist; duplicate links are ignored by the DB.
        tag_hits[img_tag.tag_id] = tag_hits.get(img_tag.tag_id, 0) + 1
        img_tag.save_or_ignore(conn)

    file_path = os.path.normpath(file_path)
    try:
        if not is_valid_media_path(file_path):
            return
        build_single_img_idx(conn, file_path, False, record_img_tag)
        # negative prompt is intentionally not indexed — unlikely search target
    except Exception as e:
        logger.error("Tag generation failed. Skipping this file. file:%s error: %s", file_path, e)
    conn.commit()
    # Flush the accumulated per-tag usage increments.
    for tag_id, incr in tag_hits.items():
        tag = Tag.get(conn, tag_id)
        tag.count += incr
        tag.save(conn)
    conn.commit()
def rebuild_image_index(search_dirs: List[str]):
    """Wipe all auto-generated tags and re-scan *search_dirs* from scratch.

    User-defined tags (type 'custom') and their links are preserved.
    """
    conn = DataBase.get_conn()
    with closing(conn.cursor()) as cur:
        # Delete the image<->tag links before the tags themselves.
        cur.execute(
            """DELETE FROM image_tag
        WHERE image_tag.tag_id IN (
            SELECT tag.id FROM tag WHERE tag.type <> 'custom'
        )
        """
        )
        cur.execute("""DELETE FROM tag WHERE tag.type <> 'custom'""")
    conn.commit()
    update_image_data(search_dirs=search_dirs, is_rebuild=True)
def get_extra_meta_keys_from_plugins(source_identifier: str):
    """Return the extra metadata keys a plugin registered for tag conversion.

    :param source_identifier: value of the image's "Source Identifier" meta
        field, used to look up the matching plugin instance.
    :return: list of extra meta keys, or [] when no plugin matches or the
        plugin misbehaves.
    """
    try:
        plugin = plugin_inst_map.get(source_identifier)
        if plugin:
            # `or []` guards against a plugin exposing None/empty here,
            # which would otherwise propagate and break the caller's
            # `keys += get_extra_meta_keys_from_plugins(...)` with a TypeError.
            return plugin.extra_convert_to_tag_meta_keys or []
    except Exception as e:
        # Best-effort: a broken plugin must never abort tag indexing.
        logger.error("get_extra_meta_keys_from_plugins %s", e)
    return []
def build_single_img_idx(conn, file_path, is_rebuild, safe_save_img_tag):
    """Index one media file: ensure a DbImg row exists and attach searchable
    tags (size, media type, selected generation-meta keys, lora/lyco names,
    positive-prompt terms) derived from its parsed generation info.

    :param conn: open database connection; the caller owns commit/rollback.
    :param file_path: normalized path of the media file.
    :param is_rebuild: True during a full index rebuild — an existing row is
        reused as-is instead of being compared against the modified date.
    :param safe_save_img_tag: callback that persists an ImageTag link and
        records the tag-count increment.
    """
    img = DbImg.get(conn, file_path)
    parsed_params = None
    if is_rebuild:
        # Rebuild: always (re)parse; keep an existing row, create one if missing.
        info = get_exif_data(file_path)
        parsed_params = info.params
        if not img:
            img = DbImg(
                file_path,
                info.raw_info,
                os.path.getsize(file_path),
                get_modified_date(file_path),
            )
            img.save(conn)
    else:
        if img:  # already indexed: skip when unchanged
            if img.date == get_modified_date(img.path):
                return
            else:
                # File changed on disk — drop the stale row (and its tag links).
                DbImg.safe_batch_remove(conn=conn, image_ids=[img.id])
        info = get_exif_data(file_path)
        parsed_params = info.params
        img = DbImg(
            file_path,
            info.raw_info,
            os.path.getsize(file_path),
            get_modified_date(file_path),
        )
        img.save(conn)
    if not parsed_params:
        # Nothing parseable (e.g. video or parse failure): row saved, no tags.
        return
    meta = parsed_params.meta
    lora = parsed_params.extra.get("lora", [])
    lyco = parsed_params.extra.get("lyco", [])
    pos = parsed_params.pos_prompt
    # Resolution tag, e.g. "512 * 768" (0 when the size meta is absent).
    size_tag = Tag.get_or_create(
        conn,
        str(meta.get("Size-1", 0)) + " * " + str(meta.get("Size-2", 0)),
        type="size",
    )
    safe_save_img_tag(ImageTag(img.id, size_tag.id))
    media_type_tag = Tag.get_or_create(conn, "Image" if is_image_file(file_path) else "Video", 'Media Type')
    safe_save_img_tag(ImageTag(img.id, media_type_tag.id))
    # Generation-meta keys promoted to first-class, searchable tags.
    keys = [
        "Model",
        "Sampler",
        "Source Identifier",
        "Postprocess upscale by",
        "Postprocess upscaler",
        "Size",
        "Refiner",
        "Hires upscaler"
    ]
    keys += get_extra_meta_keys_from_plugins(meta.get("Source Identifier", ""))
    for k in keys:
        v = case_insensitive_get(meta, k)
        if not v:
            continue
        tag = Tag.get_or_create(conn, str(v), k)
        safe_save_img_tag(ImageTag(img.id, tag.id))
        # Aggregate "… All" tags let users match any hires/refiner image
        # regardless of which specific upscaler/refiner was used.
        if "Hires upscaler" == k:
            tag = Tag.get_or_create(conn, 'Hires All', k)
            safe_save_img_tag(ImageTag(img.id, tag.id))
        elif "Refiner" == k:
            tag = Tag.get_or_create(conn, 'Refiner All', k)
            safe_save_img_tag(ImageTag(img.id, tag.id))
    for i in lora:
        tag = Tag.get_or_create(conn, i["name"], "lora")
        safe_save_img_tag(ImageTag(img.id, tag.id))
    for i in lyco:
        tag = Tag.get_or_create(conn, i["name"], "lyco")
        safe_save_img_tag(ImageTag(img.id, tag.id))
    # Each positive-prompt term becomes its own tag; negative prompt is skipped.
    for k in pos:
        tag = Tag.get_or_create(conn, k, "pos")
        safe_save_img_tag(ImageTag(img.id, tag.id))