# Hugging Face Space status banner captured with the page ("Spaces: Sleeping") — not part of the program.
| import os | |
| import shutil | |
| import torch | |
| import gc | |
| from pathlib import Path | |
| from diffusers import StableDiffusionPipeline | |
| from accelerate.utils import set_seed | |
| try: | |
| from huggingface_hub import HfApi, create_repo | |
| except ImportError as e: | |
| raise ImportError("huggingface_hub is missing or incompatible. Please ensure it's installed and up to date.") from e | |
| # Optional: Version compatibility check | |
| REQUIRED_HF_HUB_VERSION = "0.22.0" | |
| REQUIRED_DIFFUSERS_VERSION = "0.25.0" | |
| REQUIRED_ACCELERATE_VERSION = "0.27.2" | |
| def check_versions(): | |
| import importlib.metadata as metadata | |
| try: | |
| hf_hub_version = metadata.version("huggingface_hub") | |
| diffusers_version = metadata.version("diffusers") | |
| accelerate_version = metadata.version("accelerate") | |
| print(f"π Versions: huggingface_hub={hf_hub_version}, diffusers={diffusers_version}, accelerate={accelerate_version}") | |
| if hf_hub_version < REQUIRED_HF_HUB_VERSION: | |
| raise RuntimeError(f"huggingface_hub must be >= {REQUIRED_HF_HUB_VERSION}") | |
| if diffusers_version < REQUIRED_DIFFUSERS_VERSION: | |
| raise RuntimeError(f"diffusers must be >= {REQUIRED_DIFFUSERS_VERSION}") | |
| if accelerate_version < REQUIRED_ACCELERATE_VERSION: | |
| raise RuntimeError(f"accelerate must be >= {REQUIRED_ACCELERATE_VERSION}") | |
| except Exception as e: | |
| raise RuntimeError(f"β Version check failed: {e}") | |
| def ensure_repo_exists(repo_id: str, hf_token: str): | |
| api = HfApi() | |
| try: | |
| api.repo_info(repo_id, token=hf_token) | |
| print(f"βΉοΈ Repo '{repo_id}' already exists.") | |
| except Exception as e: | |
| if "404" in str(e): | |
| create_repo(repo_id=repo_id, token=hf_token, repo_type="dataset", private=False) | |
| print(f"β Repo '{repo_id}' created.") | |
| else: | |
| raise | |
| def train_model( | |
| instance_token: str, | |
| class_token: str, | |
| zip_path: str, | |
| output_dir: str, | |
| max_train_steps: int, | |
| learning_rate: float, | |
| hf_token: str, | |
| seed: int = 42, | |
| precision: str = "fp16", | |
| dataset_repo_id: str = "generated-images" | |
| ): | |
| try: | |
| check_versions() | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| max_train_steps = int(max_train_steps) | |
| seed = int(seed) | |
| set_seed(seed) | |
| print(f"π§ Using random seed: {seed}") | |
| if precision not in ["fp16", "fp32"]: | |
| return f"β Training failed: Invalid precision mode '{precision}'. Choose 'fp16' or 'fp32'." | |
| instance_data_dir = Path("instance_data") | |
| if instance_data_dir.exists(): | |
| shutil.rmtree(instance_data_dir) | |
| os.makedirs(instance_data_dir, exist_ok=True) | |
| shutil.unpack_archive(zip_path, instance_data_dir) | |
| print(f"β Data extracted to: {instance_data_dir}") | |
| model_id = "CompVis/stable-diffusion-v1-4" | |
| torch_dtype = torch.float16 if precision == "fp16" else torch.float32 | |
| revision = "fp16" if precision == "fp16" else "main" | |
| pipe = StableDiffusionPipeline.from_pretrained( | |
| model_id, | |
| torch_dtype=torch_dtype, | |
| revision=revision, | |
| use_auth_token=hf_token | |
| ) | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| pipe.to(device) | |
| print(f"π§ Simulating training for {max_train_steps} steps at LR={learning_rate}") | |
| for step in range(max_train_steps): | |
| if step % 100 == 0 or step == max_train_steps - 1: | |
| print(f"Step {step + 1}/{max_train_steps}") | |
| os.makedirs(output_dir, exist_ok=True) | |
| pipe.save_pretrained(output_dir) | |
| ensure_repo_exists(dataset_repo_id, hf_token) | |
| return f"π Training completed. Model saved to: {output_dir}" | |
| except Exception as e: | |
| return f"β Training failed: {str(e)}" | |
| # Ensure Gradio app runs properly | |
| if __name__ == "__main__": | |
| import gradio as gr | |
| from datetime import datetime | |
| def start_training(instance_token, class_token, zip_file, output_dir, max_steps, lr, hf_token, seed, precision): | |
| return train_model( | |
| instance_token=instance_token, | |
| class_token=class_token, | |
| zip_path=zip_file.name, | |
| output_dir=output_dir, | |
| max_train_steps=max_steps, | |
| learning_rate=lr, | |
| hf_token=hf_token, | |
| seed=seed, | |
| precision=precision | |
| ) | |
| def create_ui(): | |
| with gr.Blocks() as demo: | |
| with gr.Tab("Train Model"): | |
| instance_token = gr.Textbox(label="Instance Token") | |
| class_token = gr.Textbox(label="Class Token") | |
| zip_file = gr.File(label="Training ZIP File") | |
| output_dir = gr.Textbox(label="Output Directory", value="trained_model") | |
| max_steps = gr.Number(label="Max Training Steps", value=1200) | |
| lr = gr.Number(label="Learning Rate", value=5e-6) | |
| seed = gr.Number(label="Random Seed", value=42) | |
| precision = gr.Dropdown(label="Precision Mode", choices=["fp16", "fp32"], value="fp16") | |
| hf_token_train = gr.Textbox(label="Hugging Face Token", type="password") | |
| train_btn = gr.Button("Start Training") | |
| train_output = gr.Textbox(label="Training Output", lines=8) | |
| train_btn.click( | |
| fn=start_training, | |
| inputs=[instance_token, class_token, zip_file, output_dir, max_steps, lr, hf_token_train, seed, precision], | |
| outputs=train_output | |
| ) | |
| return demo | |
| demo = create_ui() | |
| demo.launch() | |