"""Simulated Stable Diffusion fine-tuning entry point.

``train_model`` extracts the instance images, loads the base pipeline, runs a
progress-printing loop that does NOT update any weights, saves the pipeline to
``output_dir`` and makes sure a Hugging Face dataset repo exists.
"""

import os
import shutil
import re
import torch
import gc
from pathlib import Path
from diffusers import StableDiffusionPipeline
from accelerate.utils import set_seed

try:
    from huggingface_hub import HfApi, create_repo
    from huggingface_hub.utils import RepositoryNotFoundError
except ImportError as e:
    raise ImportError(
        "huggingface_hub is missing or incompatible. Please ensure it's installed and up to date."
    ) from e

# Optional: Version compatibility check
REQUIRED_HF_HUB_VERSION = "0.22.0"
REQUIRED_DIFFUSERS_VERSION = "0.25.0"
REQUIRED_ACCELERATE_VERSION = "0.27.2"

_LEADING_DIGITS = re.compile(r"\d+")


def _version_tuple(version: str) -> tuple:
    """Return the numeric release components of *version* as a comparable tuple.

    Parsing stops at the first dot-separated piece without leading digits
    (e.g. ``"dev0"``); anything after the leading digits of a piece
    (``"0rc1"``, ``"0+cu118"``) is ignored, so pre-release markers are not
    considered. Trailing zeros are dropped so ``"0.25"`` equals ``"0.25.0"``.
    """
    parts = []
    for piece in version.split("."):
        match = _LEADING_DIGITS.match(piece)
        if match is None:
            break
        parts.append(int(match.group()))
    while parts and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def check_versions():
    """Print the installed library versions and enforce the minimums above.

    Raises:
        RuntimeError: if a required package is not installed or is older than
            its required minimum version.
    """
    import importlib.metadata as metadata

    requirements = (
        ("huggingface_hub", REQUIRED_HF_HUB_VERSION),
        ("diffusers", REQUIRED_DIFFUSERS_VERSION),
        ("accelerate", REQUIRED_ACCELERATE_VERSION),
    )
    try:
        installed = {name: metadata.version(name) for name, _ in requirements}
    except metadata.PackageNotFoundError as e:
        raise RuntimeError(f"❌ Version check failed: package not installed: {e}") from e

    print(
        f"🔍 Versions: huggingface_hub={installed['huggingface_hub']}, "
        f"diffusers={installed['diffusers']}, accelerate={installed['accelerate']}"
    )

    for name, minimum in requirements:
        # Compare numerically: plain string comparison misorders e.g. "0.9.0" vs "0.22.0".
        if _version_tuple(installed[name]) < _version_tuple(minimum):
            raise RuntimeError(f"❌ Version check failed: {name} must be >= {minimum}")


def ensure_repo_exists(repo_id: str, hf_token: str):
    """Create the public dataset repo *repo_id* on the Hub if it is not found.

    Errors other than "repository not found" from the Hub API propagate.
    """
    api = HfApi()
    try:
        # repo_type must match what is created below; repo_info defaults to a
        # model repo and would report an existing dataset repo as missing.
        api.repo_info(repo_id, repo_type="dataset", token=hf_token)
    except RepositoryNotFoundError:
        # exist_ok avoids a 409 if the repo exists under the token's namespace
        # (e.g. repo_id given without "user/") or was created concurrently.
        create_repo(
            repo_id=repo_id,
            token=hf_token,
            repo_type="dataset",
            private=False,
            exist_ok=True,
        )
        print(f"✅ Repo '{repo_id}' created.")
    else:
        print(f"ℹ️ Repo '{repo_id}' already exists.")


def _extract_instance_data(zip_path: str) -> Path:
    """Unpack *zip_path* into a freshly emptied ./instance_data directory."""
    instance_data_dir = Path("instance_data")
    if instance_data_dir.exists():
        shutil.rmtree(instance_data_dir)
    os.makedirs(instance_data_dir, exist_ok=True)
    shutil.unpack_archive(zip_path, instance_data_dir)
    print(f"✅ Data extracted to: {instance_data_dir}")
    return instance_data_dir


def _load_pipeline(precision: str, hf_token: str):
    """Load the base SD v1-4 pipeline in *precision* and move it to CUDA if available."""
    model_id = "CompVis/stable-diffusion-v1-4"
    torch_dtype = torch.float16 if precision == "fp16" else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        # diffusers deprecates revision="fp16" and use_auth_token in favour of
        # variant / token; this file already requires diffusers >= 0.25.
        variant="fp16" if precision == "fp16" else None,
        token=hf_token,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(device)
    return pipe


def train_model(
    instance_token: str,
    class_token: str,
    zip_path: str,
    output_dir: str,
    max_train_steps: int,
    learning_rate: float,
    hf_token: str,
    seed: int = 42,
    precision: str = "fp16",
    dataset_repo_id: str = "generated-images",
):
    """Run the simulated training job and return a human-readable status string.

    Args:
        instance_token: Currently unused (the training loop is simulated).
        class_token: Currently unused (the training loop is simulated).
        zip_path: Archive of instance images, extracted to ./instance_data.
        output_dir: Directory the loaded pipeline is saved to.
        max_train_steps: Number of simulated steps (coerced with ``int``).
        learning_rate: Only reported in the progress message.
        hf_token: Hugging Face token used to load the model and manage the repo.
        seed: Random seed passed to ``accelerate.utils.set_seed``.
        precision: ``"fp16"`` or ``"fp32"``.
        dataset_repo_id: Dataset repo created on the Hub if missing.

    Returns:
        A success message, or a message starting with "❌ Training failed:".
        Exceptions are not raised to the caller.
    """
    pipe = None
    try:
        check_versions()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        max_train_steps = int(max_train_steps)
        seed = int(seed)
        set_seed(seed)
        print(f"🔧 Using random seed: {seed}")

        if precision not in ["fp16", "fp32"]:
            return f"❌ Training failed: Invalid precision mode '{precision}'. Choose 'fp16' or 'fp32'."

        _extract_instance_data(zip_path)
        pipe = _load_pipeline(precision, hf_token)

        print(f"🚧 Simulating training for {max_train_steps} steps at LR={learning_rate}")
        for step in range(max_train_steps):
            if step % 100 == 0 or step == max_train_steps - 1:
                print(f"Step {step + 1}/{max_train_steps}")

        os.makedirs(output_dir, exist_ok=True)
        pipe.save_pretrained(output_dir)
        ensure_repo_exists(dataset_repo_id, hf_token)
        return f"🎉 Training completed. Model saved to: {output_dir}"
    except Exception as e:
        return f"❌ Training failed: {str(e)}"
    finally:
        # Release the pipeline and cached GPU memory so repeated calls
        # (e.g. from a UI) do not accumulate allocations.
        del pipe
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()