my-character-creator / train_model.py
Timtical's picture
Update train_model.py
0003569 verified
import os
import shutil
import torch
import gc
from pathlib import Path
from diffusers import StableDiffusionPipeline
from accelerate.utils import set_seed
try:
from huggingface_hub import HfApi, create_repo
except ImportError as e:
raise ImportError("huggingface_hub is missing or incompatible. Please ensure it's installed and up to date.") from e
# Minimum package versions for the optional compatibility check.
# check_versions() compares the installed version of each package against
# these values and raises RuntimeError when an installed one compares lower.
REQUIRED_HF_HUB_VERSION = "0.22.0"  # minimum huggingface_hub version
REQUIRED_DIFFUSERS_VERSION = "0.25.0"  # minimum diffusers version
REQUIRED_ACCELERATE_VERSION = "0.27.2"  # minimum accelerate version
def _release_tuple(version: str) -> tuple:
    """Return the numeric release part of *version* as a tuple of ints.

    Only the leading dotted-number prefix is used, so suffixes such as
    ``.dev0``, ``rc1`` or ``+cu118`` are ignored (a pre-release therefore
    compares equal to its final release). Trailing zero components are
    dropped so that ``"0.22"`` and ``"0.22.0"`` compare equal.

    Raises:
        ValueError: if *version* does not start with a number.
    """
    import re

    match = re.match(r"\d+(?:\.\d+)*", version.strip())
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    parts = [int(piece) for piece in match.group().split(".")]
    while len(parts) > 1 and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def check_versions(requirements: "dict[str, str] | None" = None, version_lookup=None):
    """Verify that installed packages meet their minimum versions.

    Versions are compared numerically, component by component. Comparing the
    raw strings would be lexicographic and would rank "0.9.0" above "0.22.0".

    Args:
        requirements: mapping of distribution name to minimum version. When
            omitted, the module-level ``REQUIRED_*_VERSION`` constants for
            huggingface_hub, diffusers and accelerate are used.
        version_lookup: callable taking a distribution name and returning its
            installed version string. Defaults to
            ``importlib.metadata.version``.

    Raises:
        RuntimeError: if a package is missing, its version cannot be parsed,
            or it is older than required. The original error is chained as
            ``__cause__``.
    """
    import importlib.metadata as metadata

    try:
        if requirements is None:
            requirements = {
                "huggingface_hub": REQUIRED_HF_HUB_VERSION,
                "diffusers": REQUIRED_DIFFUSERS_VERSION,
                "accelerate": REQUIRED_ACCELERATE_VERSION,
            }
        if version_lookup is None:
            version_lookup = metadata.version

        # Resolve every version before checking any, so the summary line is
        # printed even when one of the packages turns out to be too old.
        installed = {name: version_lookup(name) for name in requirements}
        summary = ", ".join(f"{name}={found}" for name, found in installed.items())
        print(f"πŸ” Versions: {summary}")

        for name, minimum in requirements.items():
            if _release_tuple(installed[name]) < _release_tuple(minimum):
                raise RuntimeError(f"{name} must be >= {minimum}")
    except Exception as e:
        # Every failure is reported under one prefix; chaining keeps the
        # original traceback available for debugging.
        raise RuntimeError(f"❌ Version check failed: {e}") from e
def ensure_repo_exists(repo_id: str, hf_token: str):
    """Create the dataset repo ``repo_id`` on the Hugging Face Hub if missing.

    Args:
        repo_id: repository identifier passed to the Hub API.
        hf_token: access token used for both the lookup and the creation.

    Raises:
        Any error from the Hub API other than "repository not found" during
        the lookup, and any error raised while creating the repository.
    """
    # Imported locally so this function does not rely on changes to the
    # file's import section.
    from huggingface_hub.utils import RepositoryNotFoundError

    api = HfApi()
    try:
        # Look the repo up as a dataset, the same type that is created below.
        # Without repo_type the lookup targets a model repo, so an existing
        # dataset repo would be reported as missing.
        api.repo_info(repo_id, repo_type="dataset", token=hf_token)
    except RepositoryNotFoundError:
        # exist_ok=True tolerates the repo being created between the lookup
        # and this call.
        create_repo(
            repo_id=repo_id,
            token=hf_token,
            repo_type="dataset",
            private=False,
            exist_ok=True,
        )
        print(f"βœ… Repo '{repo_id}' created.")
    else:
        print(f"ℹ️ Repo '{repo_id}' already exists.")
def train_model(
    instance_token: str,
    class_token: str,
    zip_path: str,
    output_dir: str,
    max_train_steps: int,
    learning_rate: float,
    hf_token: str,
    seed: int = 42,
    precision: str = "fp16",
    dataset_repo_id: str = "generated-images",
):
    """Extract the training archive, load Stable Diffusion, and save a copy.

    The training loop is a simulation: it only prints progress and does not
    update any weights. ``instance_token`` and ``class_token`` are accepted
    but not used, and ``learning_rate`` is only echoed in the progress log.

    Args:
        instance_token: unused by the current implementation.
        class_token: unused by the current implementation.
        zip_path: archive unpacked into the local ``instance_data`` directory
            (any previous contents of that directory are removed first).
        output_dir: directory the pipeline is saved to.
        max_train_steps: number of simulated steps; coerced with ``int()``.
        learning_rate: only reported in the log output.
        hf_token: Hugging Face token used to download the base model and to
            check/create the dataset repo.
        seed: random seed passed to ``accelerate.utils.set_seed``.
        precision: ``"fp16"`` or ``"fp32"``.
        dataset_repo_id: dataset repo that is created if it does not exist.

    Returns:
        A human-readable status string. Failures are reported in the returned
        string rather than raised.
    """
    # Reject invalid input before touching the filesystem, the GPU cache or
    # the global RNG state.
    if precision not in ("fp16", "fp32"):
        return f"❌ Training failed: Invalid precision mode '{precision}'. Choose 'fp16' or 'fp32'."

    pipe = None
    try:
        check_versions()

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        max_train_steps = int(max_train_steps)
        seed = int(seed)
        set_seed(seed)
        print(f"πŸ”§ Using random seed: {seed}")

        # Start from a clean directory so files from an earlier run cannot
        # leak into this one.
        instance_data_dir = Path("instance_data")
        if instance_data_dir.exists():
            shutil.rmtree(instance_data_dir)
        os.makedirs(instance_data_dir, exist_ok=True)
        # NOTE(review): the archive is extracted as-is; confirm zip_path
        # always comes from a trusted source.
        shutil.unpack_archive(zip_path, instance_data_dir)
        print(f"βœ… Data extracted to: {instance_data_dir}")

        model_id = "CompVis/stable-diffusion-v1-4"
        use_half = precision == "fp16"
        torch_dtype = torch.float16 if use_half else torch.float32
        revision = "fp16" if use_half else "main"
        # NOTE(review): newer diffusers releases appear to prefer `token=`
        # over `use_auth_token=`; confirm against the pinned version.
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            revision=revision,
            use_auth_token=hf_token,
        )
        device = "cuda" if torch.cuda.is_available() else "cpu"
        pipe.to(device)

        print(f"🚧 Simulating training for {max_train_steps} steps at LR={learning_rate}")
        last_step = max_train_steps - 1
        for step in range(max_train_steps):
            # Log every 100th step and always the final one.
            if step % 100 == 0 or step == last_step:
                print(f"Step {step + 1}/{max_train_steps}")

        os.makedirs(output_dir, exist_ok=True)
        pipe.save_pretrained(output_dir)
        ensure_repo_exists(dataset_repo_id, hf_token)
        return f"πŸŽ‰ Training completed. Model saved to: {output_dir}"
    except Exception as e:
        # Errors are returned as a status string instead of being raised.
        return f"❌ Training failed: {str(e)}"
    finally:
        # Drop the pipeline and return cached GPU memory whether the run
        # succeeded or failed, so a later call starts from a clean state.
        pipe = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()