Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| import torch | |
| import gc | |
| from pathlib import Path | |
| from diffusers import StableDiffusionPipeline | |
| from accelerate.utils import set_seed | |
| try: | |
| from huggingface_hub import HfApi, create_repo | |
| except ImportError as e: | |
| raise ImportError("huggingface_hub is missing or incompatible. Please ensure it's installed and up to date.") from e | |
# Minimum dependency versions. These are enforced, not optional: check_versions()
# raises RuntimeError when an installed package is older, and train_model()
# calls check_versions() at the start of every run.
REQUIRED_HF_HUB_VERSION = "0.22.0"
REQUIRED_DIFFUSERS_VERSION = "0.25.0"
REQUIRED_ACCELERATE_VERSION = "0.27.2"
def check_versions():
    """Verify that installed dependency versions meet the declared minimums.

    Prints the installed versions of ``huggingface_hub``, ``diffusers`` and
    ``accelerate``, then compares each one against the module-level
    ``REQUIRED_HF_HUB_VERSION`` / ``REQUIRED_DIFFUSERS_VERSION`` /
    ``REQUIRED_ACCELERATE_VERSION`` constants.

    Versions are compared numerically on their leading dotted-integer release
    segment. For example, ``0.100.0`` is newer than ``0.22.0``. A plain string
    comparison is lexicographic and gets this wrong. Suffixes such as
    ``.dev0``, ``rc1`` or ``+cu118`` are ignored, so pre-releases compare equal
    to their final release.

    Raises:
        RuntimeError: if a package is not installed, its version string cannot
            be parsed, or it is older than required. The message is prefixed
            with "β Version check failed: " and the original error is chained
            as ``__cause__``.
    """
    import importlib.metadata as metadata
    import re

    def _release(version):
        # Extract the numeric release part ("2.1.0+cu118" -> (2, 1)).
        match = re.match(r"\d+(?:\.\d+)*", version.strip())
        if match is None:
            raise ValueError(f"unparseable version string {version!r}")
        parts = [int(part) for part in match.group(0).split(".")]
        # Strip trailing zeros so that "1.0" and "1.0.0" compare equal
        # (tuple comparison would otherwise treat (1, 0) < (1, 0, 0)).
        while len(parts) > 1 and parts[-1] == 0:
            parts.pop()
        return tuple(parts)

    try:
        hf_hub_version = metadata.version("huggingface_hub")
        diffusers_version = metadata.version("diffusers")
        accelerate_version = metadata.version("accelerate")
        print(f"π Versions: huggingface_hub={hf_hub_version}, diffusers={diffusers_version}, accelerate={accelerate_version}")
        for name, installed, required in (
            ("huggingface_hub", hf_hub_version, REQUIRED_HF_HUB_VERSION),
            ("diffusers", diffusers_version, REQUIRED_DIFFUSERS_VERSION),
            ("accelerate", accelerate_version, REQUIRED_ACCELERATE_VERSION),
        ):
            if _release(installed) < _release(required):
                raise RuntimeError(f"{name} must be >= {required}")
    except Exception as e:
        # Single user-facing wrapper for every failure mode (missing package,
        # unparseable version, too-old version); keep the cause for debugging.
        raise RuntimeError(f"β Version check failed: {e}") from e
def ensure_repo_exists(repo_id: str, hf_token: str):
    """Make sure the Hugging Face *dataset* repository ``repo_id`` exists.

    The repo is looked up as a dataset repo. If the Hub reports it as not
    found, it is created as a public dataset repo. Any other lookup error
    propagates to the caller.

    Args:
        repo_id: Repository id. A bare name without a namespace is presumably
            resolved by the Hub against the token owner's account — TODO confirm.
        hf_token: Hugging Face access token, used for both lookup and creation.

    Raises:
        Whatever ``HfApi.repo_info`` / ``create_repo`` raise, other than the
        not-found case, which is handled.
    """
    from huggingface_hub.utils import RepositoryNotFoundError

    api = HfApi()
    try:
        # The lookup must use the same repo_type as the creation below.
        # Without it, the Hub is asked about a *model* repo of that name, so an
        # existing dataset repo appears missing and create_repo then conflicts.
        api.repo_info(repo_id, repo_type="dataset", token=hf_token)
        print(f"βΉοΈ Repo '{repo_id}' already exists.")
    except RepositoryNotFoundError:
        # exist_ok=True keeps this idempotent if the repo is created by someone
        # else between the lookup above and this call.
        create_repo(
            repo_id=repo_id,
            token=hf_token,
            repo_type="dataset",
            private=False,
            exist_ok=True,
        )
        print(f"β Repo '{repo_id}' created.")
def train_model(
    instance_token: str,
    class_token: str,
    zip_path: str,
    output_dir: str,
    max_train_steps: int,
    learning_rate: float,
    hf_token: str,
    seed: int = 42,
    precision: str = "fp16",
    dataset_repo_id: str = "generated-images"
):
    """Run the training job and save the resulting pipeline to ``output_dir``.

    NOTE: training is only *simulated*. The loop below just prints progress,
    and the base ``CompVis/stable-diffusion-v1-4`` pipeline is saved unchanged.
    ``instance_token``, ``class_token`` and the extracted images are not used
    yet. ``learning_rate`` is only printed.

    Steps:
      1. Check dependency versions, free cached memory and seed the RNGs.
      2. Validate ``precision``.
      3. Recreate ``./instance_data`` and unpack ``zip_path`` into it.
      4. Load the base pipeline (fp16 or fp32) and move it to CUDA if
         available, otherwise CPU.
      5. Run the simulated step loop.
      6. Save the pipeline to ``output_dir``.
      7. Ensure the dataset repo ``dataset_repo_id`` exists on the Hub.
         Nothing is uploaded to it here.

    Args:
        instance_token: Currently unused.
        class_token: Currently unused.
        zip_path: Archive (any format ``shutil.unpack_archive`` supports) with
            the instance images.
        output_dir: Directory the pipeline is saved to (created if needed).
        max_train_steps: Number of simulated steps (coerced with ``int``).
        learning_rate: Only reported in the progress message.
        hf_token: Hugging Face token for model download and repo creation.
        seed: Random seed passed to ``accelerate.utils.set_seed``.
        precision: ``"fp16"`` or ``"fp32"``.
        dataset_repo_id: Dataset repo to ensure exists after saving.

    Returns:
        A human-readable status string. Errors are never raised; any exception
        is reported as ``"β Training failed: <message>"``.
    """
    pipe = None
    try:
        check_versions()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        max_train_steps = int(max_train_steps)
        seed = int(seed)
        set_seed(seed)
        print(f"π§ Using random seed: {seed}")

        if precision not in ["fp16", "fp32"]:
            return f"β Training failed: Invalid precision mode '{precision}'. Choose 'fp16' or 'fp32'."

        # Start from a clean extraction directory (relative to the CWD) so
        # files from a previous run do not mix with this archive's contents.
        instance_data_dir = Path("instance_data")
        if instance_data_dir.exists():
            shutil.rmtree(instance_data_dir)
        os.makedirs(instance_data_dir, exist_ok=True)
        shutil.unpack_archive(zip_path, instance_data_dir)
        print(f"β Data extracted to: {instance_data_dir}")

        model_id = "CompVis/stable-diffusion-v1-4"
        torch_dtype = torch.float16 if precision == "fp16" else torch.float32
        # NOTE(review): newer diffusers prefer variant="fp16" over the "fp16"
        # branch; kept as a revision because the repo's main branch may not
        # ship fp16 variant files — confirm before switching.
        revision = "fp16" if precision == "fp16" else "main"
        # `token` replaces the deprecated `use_auth_token` keyword in the
        # diffusers versions this module requires (>= 0.25).
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            revision=revision,
            token=hf_token,
        )
        device = "cuda" if torch.cuda.is_available() else "cpu"
        pipe.to(device)

        print(f"π§ Simulating training for {max_train_steps} steps at LR={learning_rate}")
        for step in range(max_train_steps):
            # Report every 100th step plus the final one.
            if step % 100 == 0 or step == max_train_steps - 1:
                print(f"Step {step + 1}/{max_train_steps}")

        os.makedirs(output_dir, exist_ok=True)
        pipe.save_pretrained(output_dir)
        ensure_repo_exists(dataset_repo_id, hf_token)
        return f"π Training completed. Model saved to: {output_dir}"
    except Exception as e:
        return f"β Training failed: {str(e)}"
    finally:
        # Release the pipeline's weights on every exit path (success, early
        # return, or failure) so repeated runs do not accumulate GPU memory.
        del pipe
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()