# cv_test / main.py
# Origin: Hugging Face Space by tktkdrrrrrrrrrrr — commit b60e664 ("Update main.py", verified)
import asyncio
import datetime
import json
import logging
import os
import re
import shutil
import subprocess
import time
from typing import Optional
import requests
from bs4 import BeautifulSoup
from fake_useragent import UserAgent
from fastapi import FastAPI
from huggingface_hub import HfApi, create_repo, hf_hub_download, login
# Logging configuration: INFO-level root logger plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Config:
    """Configuration constants for the crawler (evaluated once at import time)."""
    # Required secrets — a KeyError is raised at import time if either is unset.
    HUGGINGFACE_API_KEY = os.environ["HUGGINGFACE_API_KEY"]
    CIVITAI_API_TOKEN = os.environ["CIVITAI_API_TOKEN"]
    # Local file names, mirrored to/from the Hugging Face repos below.
    LOG_FILE = "civitai_backup.log"
    LIST_FILE = "model_list.log"
    # Target Hugging Face repositories; "current" is filled in at runtime
    # from the second line of the downloaded log file.
    REPO_IDS = {
        "log": "tktkdrrrrrrrrrrr/CivitAI_log_test",
        "model_list": "tktkdrrrrrrrrrrr/CivitAI_model_info_test",
        "current": ""
    }
    # CivitAI REST API endpoints.
    URLS = {
        "latest": "https://civitai.com/api/v1/models?sort=Newest",
        "modelPage": "https://civitai.com/models/",
        "modelId": "https://civitai.com/api/v1/models/",
        "modelVersionId": "https://civitai.com/api/v1/model-versions/",
        "hash": "https://civitai.com/api/v1/model-versions/by-hash/"
    }
    JST = datetime.timezone(datetime.timedelta(hours=9))  # Japan Standard Time (UTC+9)
    UA = UserAgent()
    # NOTE(review): UA.random is evaluated once here, so every request reuses
    # the same User-Agent string for the life of the process.
    HEADERS = {
        'Authorization': f'Bearer {CIVITAI_API_TOKEN}',
        'User-Agent': UA.random,
        "Content-Type": "application/json"
    }
class CivitAICrawler:
    """Downloads new models from CivitAI and backs them up to Hugging Face."""

    # (connect, read) timeout in seconds applied to every outbound HTTP
    # request so the long-running crawl loop can never hang forever on a
    # stalled connection.
    REQUEST_TIMEOUT = (10, 60)

    def __init__(self, config: "Config"):
        self.config = config
        self.api = HfApi()
        self.app = FastAPI()
        # Work on a copy so Config.REPO_IDS itself is never mutated.
        self.repo_ids = self.config.REPO_IDS.copy()
        self.jst = self.config.JST
        self.setup_routes()

    def setup_routes(self):
        """Register the FastAPI routes and the startup hook."""
        @self.app.get("/")
        def read_root():
            now = str(datetime.datetime.now(self.jst))
            # User-facing status text (kept verbatim, including the Japanese).
            description = f"""
CivitAIを定期的に周回し新規モデルを osanpo/CivitAI_Auto10 にバックアップするspaceです。
モデルページ名とバックアップURLの紐づきはhttps://huggingface.co/{self.repo_ids['model_list']}/blob/main/model_list.logからどうぞ
たまに覗いてもらえると動き続けると思います。
再起動が必要になっている場合はRestartボタンを押してもらえると助かります。
Status: {now} + currently running :D
"""
            return description

        @self.app.on_event("startup")
        async def startup_event():
            # Kick off the crawl loop in the background once the app is up.
            asyncio.create_task(self.crawl())

    @staticmethod
    def get_filename_from_cd(content_disposition: Optional[str], default_name: str) -> str:
        """Extract the file name from a Content-Disposition header.

        Falls back to ``default_name`` when the header is missing or carries
        no ``filename=`` parameter.
        """
        if content_disposition:
            for part in content_disposition.split(';'):
                if "filename=" in part:
                    # Split only on the first '=' so file names that
                    # themselves contain '=' are preserved intact.
                    return part.split("=", 1)[1].strip().strip('"')
        return default_name

    def download_file(self, url: str, destination_folder: str, default_name: str) -> bool:
        """Download ``url`` into ``destination_folder``.

        Returns True on success, False when the HTTP request failed (no file
        is created in that case, so callers can retry).
        """
        try:
            response = requests.get(url, headers=self.config.HEADERS, stream=True,
                                    timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()  # surfaces HTTP errors (404 etc.)
        except requests.RequestException as e:
            logger.error(f"Failed to download file from {url}: {e}")
            return False  # bail out so no empty file is created
        filename = self.get_filename_from_cd(response.headers.get('content-disposition'), default_name)
        file_path = os.path.join(destination_folder, filename)
        # Stream the body to disk in chunks.
        with open(file_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
        logger.info(f"Download completed: {file_path}")
        return True

    def get_model_info(self, model_id: str) -> dict:
        """Fetch model metadata from the CivitAI API; returns {} on failure.

        (Previously fell off the end and returned None on error, breaking the
        declared ``dict`` contract and crashing callers.)
        """
        try:
            response = requests.get(self.config.URLS["modelId"] + str(model_id),
                                    headers=self.config.HEADERS,
                                    timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            logger.error(f"Failed to retrieve model info for ID {model_id}: {e}")
            return {}

    def _download_with_retries(self, download_url: str, folder: str, file_name: str) -> bool:
        """Download one file with up to 5 attempts, handling CivitAI's login page.

        CivitAI sometimes serves a file literally named 'login' instead of the
        asset; that counts as a failed attempt. On total failure a
        ``<name>.download_failed`` marker file is written so the failure stays
        visible in the backup, and False is returned.
        """
        login_detected_count = 0  # failed-attempt counter for this file
        while login_detected_count < 5:
            try:
                ok = self.download_file(download_url, folder, file_name)
            except Exception as e:
                logger.error(f"Exception occurred while downloading {file_name}: {e}")
                login_detected_count += 1
                continue
            if not ok:
                # HTTP failure: count it as an attempt and retry (previously a
                # failed request was logged as a success and never retried).
                login_detected_count += 1
                continue
            if "login" in os.listdir(folder):
                login_detected_count += 1
                logger.warning(f"'login' file found while downloading {file_name}. Will try again. ({login_detected_count}/5)")
                os.remove(os.path.join(folder, "login"))
            else:
                logger.info(f"Successfully downloaded {file_name}")
                return True
        # All attempts failed: create a dummy marker file.
        dummy_file_name = f"{file_name}.download_failed"
        dummy_file_path = os.path.join(folder, dummy_file_name)
        try:
            with open(dummy_file_path, "w") as f:
                f.write("Download failed after 5 attempts.")
            logger.error(f"Failed to download {file_name}. Created dummy file {dummy_file_name}. URL: {download_url}")
        except Exception as e:
            logger.error(f"Failed to create dummy file for {file_name}: {e}")
        return False

    def download_model(self, model_versions: list, folder: str,
                       existing_old_version_files: Optional[list] = None):
        """Download the latest version's files, then archive older versions.

        Args:
            model_versions: ``modelVersions`` list from the API, newest first.
            folder: local destination folder (mirrors the folder in the repo).
            existing_old_version_files: old-version file names already present
                in the backup repo; those are skipped. The default used to be
                a mutable ``[]`` (shared between calls); now a None sentinel.
        """
        if existing_old_version_files is None:
            existing_old_version_files = []
        # Latest version: download only — it is uploaded later with the folder.
        latest_version = model_versions[0]
        for file_info in latest_version["files"]:
            self._download_with_retries(file_info["downloadUrl"], folder, file_info["name"])
        # Older versions: download, upload individually, then delete locally.
        if len(model_versions) > 1:
            old_versions_folder = os.path.join(folder, "old_versions")
            os.makedirs(old_versions_folder, exist_ok=True)
            for version in model_versions[1:]:
                for file_info in version["files"]:
                    file_name = file_info["name"]
                    if file_name in existing_old_version_files:
                        logger.info(f"Skipping download of existing old version file: {file_name}")
                        continue
                    local_file_path = os.path.join(old_versions_folder, file_name)
                    if not self._download_with_retries(file_info["downloadUrl"],
                                                       old_versions_folder, file_name):
                        continue  # download failed for good — next file
                    # Upload only after a successful download.
                    try:
                        path_in_repo = os.path.join(os.path.basename(folder), "old_versions", file_name)
                        self.upload_file(
                            file_path=local_file_path,
                            path_in_repo=path_in_repo
                        )
                        # Free local disk space once the upload succeeded.
                        os.remove(local_file_path)
                        logger.info(f"Deleted local file {local_file_path} after successful upload.")
                    except Exception as e:
                        logger.error(f"Failed to upload and delete file {local_file_path}: {e}")

    def download_images(self, model_versions: list, folder: str):
        """Download all preview images and pack them into a password-protected zip."""
        images_folder = os.path.join(folder, "images")
        os.makedirs(images_folder, exist_ok=True)
        # Collect every image URL across all versions.
        images = [img["url"]
                  for version in model_versions
                  for img in version.get("images", [])]
        for image_url in images:
            image_name = image_url.split("/")[-1]
            try:
                response = requests.get(image_url, timeout=self.REQUEST_TIMEOUT)
                response.raise_for_status()
                with open(os.path.join(images_folder, f"{image_name}.png"), "wb") as file:
                    file.write(response.content)
            except requests.RequestException as e:
                logger.error(f"Error downloading image {image_url}: {e}")
        # Compress the image folder into a password-protected zip via the
        # external 'zip' tool (stdlib zipfile cannot encrypt).
        original_cwd = os.getcwd()  # captured before the try so finally can always restore it
        try:
            os.chdir(folder)
            subprocess.run(
                ['zip', '-e', '--password=osanpo', 'images.zip', '-r', 'images'],
                check=True
            )
            logger.info(f"Images compressed and saved to {os.path.join(folder, 'images.zip')}")
        except subprocess.CalledProcessError as e:
            logger.error(f"Error creating zip file: {e}")
        finally:
            os.chdir(original_cwd)
        # Remove the uncompressed image folder after zipping.
        if os.path.exists(images_folder):
            shutil.rmtree(images_folder)

    def save_html_content(self, url: str, folder: str):
        """Fetch ``url`` and save its HTML body as ``<folder>/<folder>.html``."""
        try:
            response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            html_path = os.path.join(folder, f"{folder}.html")
            with open(html_path, 'w', encoding='utf-8') as file:
                file.write(response.text)
        except Exception as e:
            logger.error(f"Error saving HTML content for URL {url}: {e}")

    @staticmethod
    def save_model_info(model_info: dict, folder: str):
        """Persist the raw model metadata as model_info.json inside ``folder``."""
        with open(os.path.join(folder, "model_info.json"), "w") as file:
            json.dump(model_info, file, indent=2)

    @staticmethod
    def increment_repo_name(repo_id: str) -> str:
        """Increment a trailing number in ``repo_id`` ("…Auto10" -> "…Auto11").

        When there is no trailing number, "1" is appended instead.
        """
        match = re.search(r'(\d+)$', repo_id)
        if match:
            number = int(match.group(1)) + 1
            return re.sub(r'\d+$', str(number), repo_id)
        return f"{repo_id}1"

    def upload_file(self, file_path: str, repo_id: Optional[str] = None, path_in_repo: Optional[str] = None):
        """Upload one file to a Hugging Face repo, retrying on known failures.

        Retry policy:
          * "over the limit of 100000 files" -> roll over to a new incremented
            repository and restart the retry budget.
          * "retry this action in about 1 hour" -> sleep one hour and retry
            without consuming the budget.
          * anything else -> up to ``max_retries`` plain retries, then re-raise.
        """
        if repo_id is None:
            repo_id = self.repo_ids['current']
        if path_in_repo is None:
            path_in_repo = os.path.basename(file_path)
        max_retries = 5
        attempt = 0
        while attempt < max_retries:
            try:
                self.api.upload_file(
                    path_or_fileobj=file_path,
                    repo_id=repo_id,
                    path_in_repo=path_in_repo
                )
                logger.info(f"Uploaded file {file_path} to repository {repo_id} at {path_in_repo}.")
                return
            except Exception as e:
                attempt += 1
                error_message = str(e)
                if "over the limit of 100000 files" in error_message:
                    logger.warning("Repository file limit exceeded, creating a new repository.")
                    self.repo_ids['current'] = self.increment_repo_name(self.repo_ids['current'])
                    self.api.create_repo(repo_id=self.repo_ids['current'], private=True)
                    # BUG FIX: retry against the freshly created repo — the
                    # local repo_id previously kept pointing at the full one.
                    repo_id = self.repo_ids['current']
                    attempt = 0  # fresh retry budget for the new repository
                    continue
                elif "you can retry this action in about 1 hour" in error_message:
                    logger.warning("Encountered 'retry in 1 hour' error. Waiting for 1 hour before retrying...")
                    time.sleep(3600)
                    attempt -= 1  # rate-limit waits don't consume the budget
                else:
                    if attempt < max_retries:
                        logger.warning(f"Failed to upload file {file_path}, attempt {attempt}/{max_retries}. Retrying...")
                    else:
                        logger.error(f"Failed to upload file {file_path} after {max_retries} attempts: {e}")
                        raise

    def upload_folder(self, folder_path: str, path_in_repo: Optional[str] = None):
        """Upload a whole folder to the current repo; same retry policy as upload_file."""
        if path_in_repo is None:
            path_in_repo = os.path.basename(folder_path)
        max_retries = 5
        attempt = 0
        while attempt < max_retries:
            try:
                self.api.upload_folder(
                    folder_path=folder_path,
                    repo_id=self.repo_ids['current'],
                    path_in_repo=path_in_repo
                )
                logger.info(f"Uploaded folder {folder_path} to repository {self.repo_ids['current']} at {path_in_repo}.")
                return
            except Exception as e:
                attempt += 1
                error_message = str(e)
                if "over the limit of 100000 files" in error_message:
                    logger.warning("Repository file limit exceeded, creating a new repository.")
                    self.repo_ids['current'] = self.increment_repo_name(self.repo_ids['current'])
                    self.api.create_repo(repo_id=self.repo_ids['current'], private=True)
                    attempt = 0  # fresh retry budget for the new repository
                    continue
                elif "you can retry this action in about 1 hour" in error_message:
                    logger.warning("Encountered 'retry in 1 hour' error. Waiting for 1 hour before retrying...")
                    time.sleep(3600)
                    attempt -= 1  # rate-limit waits don't consume the budget
                else:
                    if attempt < max_retries:
                        logger.warning(f"Failed to upload folder {folder_path}, attempt {attempt}/{max_retries}. Retrying...")
                    else:
                        logger.error(f"Failed to upload folder {folder_path} after {max_retries} attempts: {e}")
                        raise

    def read_model_list(self) -> dict:
        """Read the local model list file into a {backup_url: modelpage_name} dict.

        Returns {} when the file is missing or unreadable.
        """
        model_list = {}
        try:
            with open(self.config.LIST_FILE, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    # Each line has the form "{modelpage_name}: {model_hf_url}".
                    parts = line.split(": ", 1)
                    if len(parts) == 2:
                        modelpage_name, model_hf_url = parts
                        # Keyed by URL so process_model can test membership cheaply.
                        model_list[model_hf_url] = modelpage_name
            return model_list
        except Exception as e:
            logger.error(f"Failed to read model list: {e}")
            return {}

    def get_repo_info(self, repo_id) -> list:
        """Return the list of file paths in ``repo_id``, or [] on failure."""
        try:
            repo_info = self.api.repo_info(repo_id=repo_id, files_metadata=True)
            return [sibling.rfilename for sibling in repo_info.siblings]
        except Exception as e:
            logger.error(f"Failed to get repo info for {repo_id}: {e}")
            return []

    def process_model(self, model_url: str):
        """Back up one model: files, preview images, HTML page and metadata JSON."""
        try:
            model_id = model_url.rstrip("/").split("/")[-1]
            model_info = self.get_model_info(model_id)
            model_versions = model_info.get("modelVersions", [])
            if not model_versions:
                # Metadata fetch failed or the model has no versions.
                logger.error(f"No model versions found for model ID {model_id}; skipping.")
                return
            latest_version = model_versions[0]
            # Name the local folder after the 'Model'-type file of the latest version.
            model_file = next(
                (file for file in latest_version["files"] if file.get('type') == 'Model'),
                None
            )
            if model_file:
                latest_filename = model_file['name']
            else:
                # No 'Model'-type file: fall back to the first file's name.
                latest_filename = latest_version["files"][0]['name']
                logger.warning(f"No 'Model' type file found for model ID {model_id}. Using first file's name.")
            folder = os.path.splitext(latest_filename)[0]
            os.makedirs(folder, exist_ok=True)
            # URL under which this model will live in the backup repo.
            model_hf_url = f"https://huggingface.co/{self.repo_ids['current']}/tree/main/{folder}"
            model_list = self.read_model_list()
            if model_hf_url in model_list:
                # Already backed up once: list old-version files that exist in
                # the repo so they are not downloaded again.
                repo_files = self.get_repo_info(self.repo_ids['current'])
                old_versions_files = [f for f in repo_files if f.startswith(f"{folder}/old_versions/")]
                existing_old_version_files = [os.path.basename(f) for f in old_versions_files]
            else:
                existing_old_version_files = []
            self.download_model(model_versions, folder, existing_old_version_files)
            self.download_images(model_versions, folder)
            self.save_html_content(model_url, folder)
            self.save_model_info(model_info, folder)
            self.upload_folder(folder)
            # Append "<model page name>: <backup url>" to the model list file.
            modelpage_name = model_info.get("name", "Unnamed Model")
            model_hf_url = f"https://huggingface.co/{self.repo_ids['current']}/tree/main/{folder}"
            with open(self.config.LIST_FILE, "a", encoding="utf-8") as f:
                f.write(f"{modelpage_name}: {model_hf_url}\n")
            # Clean up the local working folder.
            if os.path.exists(folder):
                shutil.rmtree(folder)
        except Exception as e:
            logger.error(f"Unexpected error processing model ({model_url}): {e}")

    async def crawl(self):
        """Main loop: poll CivitAI for new models and back them up forever."""
        while True:
            try:
                login(token=self.config.HUGGINGFACE_API_KEY, add_to_git_credential=True)
                # Pull the current model list from its repo…
                model_list_path = hf_hub_download(repo_id=self.repo_ids['model_list'], filename=self.config.LIST_FILE)
                shutil.copyfile(model_list_path, f"./{self.config.LIST_FILE}")
                # …and the crawl log from its repo.
                local_file_path = hf_hub_download(repo_id=self.repo_ids["log"], filename=self.config.LOG_FILE)
                shutil.copyfile(local_file_path, f"./{self.config.LOG_FILE}")
                # Log line 1: JSON list of known model IDs; line 2: current repo id.
                with open(self.config.LOG_FILE, "r", encoding="utf-8") as file:
                    lines = file.read().splitlines()
                    old_models = json.loads(lines[0]) if len(lines) > 0 else []
                    self.repo_ids["current"] = lines[1] if len(lines) > 1 else ""
                # Fetch the newest models from CivitAI.
                response = requests.get(self.config.URLS["latest"], headers=self.config.HEADERS,
                                        timeout=self.REQUEST_TIMEOUT)
                response.raise_for_status()
                latest_models = response.json().get("items", [])
                latest_model_ids = [item.get("id") for item in latest_models if "id" in item]
                # Keep the API order (newest first) — the previous set
                # difference produced an arbitrary "first" model.
                known_ids = set(old_models)
                new_models = [mid for mid in latest_model_ids if mid not in known_ids]
                if new_models:
                    logger.info(f"New models found: {new_models}")
                    model_id = new_models[0]  # one model per loop iteration
                    for attempt in range(1, 6):
                        try:
                            self.process_model(f"{self.config.URLS['modelId']}{model_id}")
                            break
                        except Exception as e:
                            logger.error(f"Failed to process model ID {model_id} (Attempt {attempt}/5): {e}")
                            if attempt == 5:
                                logger.error(f"Skipping model ID {model_id} after 5 failed attempts.")
                            else:
                                await asyncio.sleep(2)
                else:
                    # No new models: overwrite the log with the latest IDs
                    # (prunes IDs that fell out of the "Newest" window).
                    with open(self.config.LOG_FILE, "w", encoding="utf-8") as f:
                        f.write(json.dumps(latest_model_ids) + "\n")
                        f.write(f"{self.repo_ids['current']}\n")
                    logger.info(f"Updated log file: {self.config.LOG_FILE}")
                    self.upload_file(
                        file_path=self.config.LOG_FILE,
                        repo_id=self.repo_ids["log"],
                        path_in_repo=self.config.LOG_FILE
                    )
                    logger.info("Uploaded log file to repository.")
                    logger.info("No new models found.")
                    await asyncio.sleep(60)
                    continue
                # Record the processed model and push both bookkeeping files.
                old_models.append(model_id)
                with open(self.config.LOG_FILE, "w", encoding="utf-8") as f:
                    f.write(json.dumps(old_models) + "\n")
                    f.write(f"{self.repo_ids['current']}\n")
                logger.info(f"Updated log file with new model ID: {model_id}")
                self.upload_file(
                    file_path=self.config.LOG_FILE,
                    repo_id=self.repo_ids["log"],
                    path_in_repo=self.config.LOG_FILE
                )
                self.upload_file(
                    file_path=self.config.LIST_FILE,
                    repo_id=self.repo_ids["model_list"],
                    path_in_repo=self.config.LIST_FILE
                )
            except Exception as e:
                logger.error(f"Error during crawling: {e}")
                await asyncio.sleep(300)
# Expose the FastAPI application at module level.
config = Config()
crawler = CivitAICrawler(config)
app = crawler.app  # defined at module level so Uvicorn can reference it