GetModelLora / app.py
tomo2chin2's picture
Update app.py
b401ca4 verified
import os
import json
import re
import requests
import gradio as gr
from huggingface_hub import HfApi, hf_hub_download
# 環境変数から取得
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_REPO = os.environ.get("MODEL_REPO", None)
LORA_REPO = os.environ.get("LORA_REPO", None)
EMBE_REPO = os.environ.get("EMBE_REPO", None) # embedding用リポジトリ
CIVITAI_TOKEN = os.environ.get("CIVITAI_TOKEN", None) # Civitai用トークン
api = HfApi()
def download_and_upload(model_id, repo_type):
print(f"モデルID: {model_id}, リポジトリタイプ: {repo_type}")
if not model_id or not repo_type:
return "モデルIDまたはタイプが未指定です。"
headers = {}
if CIVITAI_TOKEN:
headers["Authorization"] = f"Bearer {CIVITAI_TOKEN}"
# モデル情報の取得
try:
response = requests.get(f"https://civitai.com/api/v1/models/{model_id}", headers=headers)
response.raise_for_status()
model_data = response.json()
print("モデル情報を取得しました。")
except Exception as e:
print(f"モデル情報の取得に失敗しました: {e}")
return f"モデル情報の取得に失敗しました: {e}"
# 最新のモデルバージョンを取得
model_versions = model_data.get("modelVersions", [])
if not model_versions:
return "モデルバージョンが見つかりませんでした。"
latest_version = model_versions[0] # 最新バージョンを取得
trigger_words = latest_version.get("trainedWords", [])
files = latest_version.get("files", [])
if not files:
return "モデルファイルが見つかりませんでした。"
download_url = files[0].get("downloadUrl")
if not download_url:
return "ダウンロードURLが見つかりませんでした。"
# ファイルのダウンロード
try:
response = requests.get(download_url, headers=headers)
response.raise_for_status()
print("ファイルをダウンロードしました。")
except Exception as e:
print(f"ファイルのダウンロードに失敗しました: {e}")
return f"ファイルのダウンロードに失敗しました: {e}"
# ファイル名の設定
filename = files[0].get("name", "downloaded_file.bin")
with open(filename, "wb") as f:
f.write(response.content)
print(f"ファイルを保存しました: {filename}")
if repo_type == "model":
target_repo = MODEL_REPO
elif repo_type == "lora":
target_repo = LORA_REPO
elif repo_type == "embedding":
target_repo = EMBE_REPO
else:
return "不正なリポジトリタイプです。"
if not target_repo:
return f"{repo_type}用リポジトリが環境変数で設定されていません。"
# リポジトリ内のファイル一覧を取得
try:
repo_files = api.list_repo_files(repo_id=target_repo, repo_type="dataset", token=HF_TOKEN)
print(f"リポジトリ内のファイル一覧を取得しました。")
except Exception as e:
print(f"リポジトリ内のファイル一覧の取得に失敗しました: {e}")
return f"リポジトリ内のファイル一覧の取得に失敗しました: {e}"
# メタデータファイル名
metadata_filename = "metadata.json"
# リポジトリからメタデータファイルをダウンロード
try:
metadata_file = hf_hub_download(
repo_id=target_repo,
filename=metadata_filename,
repo_type="dataset",
token=HF_TOKEN,
force_download=True
)
with open(metadata_file, 'r', encoding='utf-8') as f:
metadata = json.load(f)
print("既存のメタデータを読み込みました。")
except Exception as e:
# メタデータファイルが存在しない場合は新規作成
metadata = []
print("メタデータファイルが存在しないため、新規作成します。")
# メタデータを更新(存在しないファイルを削除)
updated_metadata = []
for entry in metadata:
if entry['filename'] in repo_files:
updated_metadata.append(entry)
else:
print(f"存在しないファイルのメタデータを削除: {entry['filename']}")
# ファイル名の重複チェック
existing_filenames = [e['filename'] for e in updated_metadata]
if filename in existing_filenames:
print(f"ファイル名が既に存在します。ファイル名: {filename}")
return f"ファイル名 '{filename}' は既にリポジトリに存在します。"
# 新しい通し番号の設定
sequence_numbers = [e['sequence_number'] for e in updated_metadata]
if sequence_numbers:
new_sequence_number = max(sequence_numbers) + 1
else:
new_sequence_number = 1
# 新しいエントリを追加
new_entry = {
"sequence_number": new_sequence_number,
"filename": filename,
"trigger_words": trigger_words
}
updated_metadata.append(new_entry)
print(f"新しいメタデータを追加しました: {new_entry}")
# メタデータをファイルに保存
with open(metadata_filename, 'w', encoding='utf-8') as f:
json.dump(updated_metadata, f, ensure_ascii=False, indent=4)
print(f"メタデータファイルを更新しました。ファイル名: {metadata_filename}")
# ファイルのアップロード
try:
# 元のファイルをアップロード
api.upload_file(
path_or_fileobj=filename,
path_in_repo=filename,
repo_id=target_repo,
token=HF_TOKEN,
repo_type="dataset",
commit_message=f"Add {filename} from model ID {model_id}"
)
# 更新したメタデータファイルをアップロード
api.upload_file(
path_or_fileobj=metadata_filename,
path_in_repo=metadata_filename,
repo_id=target_repo,
token=HF_TOKEN,
repo_type="dataset",
commit_message="Update metadata"
)
print(f"ファイルとメタデータをアップロードしました。")
return f"ファイル '{filename}' とメタデータを '{target_repo}' にアップロードしました。"
except Exception as e:
print(f"アップロード中にエラーが発生しました: {e}")
return f"アップロード中にエラーが発生しました: {e}"
# Gradio UI構築
with gr.Blocks() as demo:
gr.Markdown("## CivitaiモデルIDからファイルダウンロード&アップロードツール")
gr.Markdown("CivitaiのモデルIDを指定して、'model'、'lora'、または'embedding'リポジトリへアップロードします。")
# ダウンロード&アップロード機能
model_id_input = gr.Textbox(
label="CivitaiモデルID",
placeholder="例: 1102"
)
repo_choice = gr.Radio(
choices=["model", "lora", "embedding"], # "embedding" を追加
label="アップロード先タイプの選択",
value="model"
)
run_button = gr.Button("実行")
output = gr.Textbox(label="結果メッセージ", interactive=False)
run_button.click(
download_and_upload,
inputs=[model_id_input, repo_choice],
outputs=output
)
demo.launch()