import gc
import os
import re
import shutil
from pathlib import Path
from typing import Tuple

import torch
from diffusers import StableDiffusionPipeline
from accelerate.utils import set_seed

try:
    from huggingface_hub import HfApi, create_repo
except ImportError as e:
    raise ImportError(
        "huggingface_hub is missing or incompatible. Please ensure it's installed and up to date."
    ) from e

# Optional: Version compatibility check
REQUIRED_HF_HUB_VERSION = "0.22.0"
REQUIRED_DIFFUSERS_VERSION = "0.25.0"
REQUIRED_ACCELERATE_VERSION = "0.27.2"


def _version_key(version: str) -> Tuple[int, ...]:
    """Return the leading numeric release components of *version* as ints.

    Parsing stops at the first dot-separated piece that does not start with
    a digit, so ``"0.25.0.dev0"`` yields ``(0, 25, 0)`` and ``"1.0.0rc1"``
    yields ``(1, 0, 0)``. Pre-release/dev suffixes are therefore ignored.
    """
    release = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        release.append(int(match.group()))
    return tuple(release)


def _is_older(installed: str, required: str) -> bool:
    """Return True when *installed* is numerically lower than *required*."""
    have = _version_key(installed)
    need = _version_key(required)
    # Pad to equal length so "0.22" and "0.22.0" compare as equal.
    width = max(len(have), len(need))
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have < need


def check_versions() -> None:
    """Verify that the installed libraries meet the minimum versions.

    Prints the detected versions of ``huggingface_hub``, ``diffusers`` and
    ``accelerate`` and compares each against its ``REQUIRED_*`` constant.

    Raises:
        RuntimeError: If a package is missing, or is older than required.
            The original error is chained as the cause.
    """
    import importlib.metadata as metadata

    minimums = {
        "huggingface_hub": REQUIRED_HF_HUB_VERSION,
        "diffusers": REQUIRED_DIFFUSERS_VERSION,
        "accelerate": REQUIRED_ACCELERATE_VERSION,
    }
    try:
        installed = {name: metadata.version(name) for name in minimums}
        summary = ", ".join(f"{name}={found}" for name, found in installed.items())
        print(f"🔍 Versions: {summary}")
        for name, minimum in minimums.items():
            # Versions are compared numerically: a plain string comparison
            # would rank "0.9.0" above "0.22.0".
            if _is_older(installed[name], minimum):
                raise RuntimeError(f"{name} must be >= {minimum}")
    except Exception as e:
        raise RuntimeError(f"❌ Version check failed: {e}") from e


def ensure_repo_exists(repo_id: str, hf_token: str) -> None:
    """Create the dataset repo *repo_id* on the Hub unless it already exists.

    Args:
        repo_id: Identifier of the dataset repository.
        hf_token: Hugging Face access token used for both API calls.

    Raises:
        Exception: Any error from the existence probe whose message does not
            contain "404" is re-raised unchanged.
    """
    api = HfApi()
    try:
        # Probe the same repo type that is created below; without an explicit
        # repo_type the lookup targets a different kind of repository and an
        # existing dataset repo would be reported as missing.
        api.repo_info(repo_id, repo_type="dataset", token=hf_token)
        print(f"ℹ️ Repo '{repo_id}' already exists.")
    except Exception as e:
        if "404" not in str(e):
            raise
        # exist_ok keeps creation idempotent if the repo appears between the
        # probe and this call, or the probe resolved the id differently.
        create_repo(
            repo_id=repo_id,
            token=hf_token,
            repo_type="dataset",
            private=False,
            exist_ok=True,
        )
        print(f"✅ Repo '{repo_id}' created.")


def train_model(
    instance_token: str,
    class_token: str,
    zip_path: str,
    output_dir: str,
    max_train_steps: int,
    learning_rate: float,
    hf_token: str,
    seed: int = 42,
    precision: str = "fp16",
    dataset_repo_id: str = "generated-images",
) -> str:
    """Run the (currently simulated) training pipeline and report the outcome.

    The function checks library versions, seeds the RNGs, unpacks the training
    archive into ``instance_data/``, loads the base Stable Diffusion pipeline,
    loops over the requested number of steps printing progress only (no
    weights are updated), saves the pipeline to *output_dir*, and ensures the
    dataset repo exists.

    Args:
        instance_token: Accepted for interface compatibility; not used here.
        class_token: Accepted for interface compatibility; not used here.
        zip_path: Path to the archive with training data.
        output_dir: Directory the pipeline is saved to.
        max_train_steps: Number of simulated steps (coerced to ``int``).
        learning_rate: Only echoed in the progress message.
        hf_token: Hugging Face token for model download and repo access.
        seed: Random seed (coerced to ``int``).
        precision: ``"fp16"`` or ``"fp32"``.
        dataset_repo_id: Dataset repository to create if missing.

    Returns:
        A human-readable status string. Failures are reported through the
        return value (prefixed with "❌ Training failed:") rather than raised.
    """
    try:
        check_versions()

        # Release any memory held by a previous run before loading the model.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # UI widgets may deliver floats; normalise to ints up front.
        max_train_steps = int(max_train_steps)
        seed = int(seed)
        set_seed(seed)
        print(f"🔧 Using random seed: {seed}")

        if precision not in ["fp16", "fp32"]:
            return f"❌ Training failed: Invalid precision mode '{precision}'. Choose 'fp16' or 'fp32'."

        # Start from an empty directory so files from earlier runs don't leak in.
        instance_data_dir = Path("instance_data")
        if instance_data_dir.exists():
            shutil.rmtree(instance_data_dir)
        os.makedirs(instance_data_dir, exist_ok=True)
        # NOTE(review): the archive comes from a user upload; unpack_archive
        # does not by itself guard against hostile member paths - confirm
        # whether uploads are trusted.
        shutil.unpack_archive(zip_path, instance_data_dir)
        print(f"✅ Data extracted to: {instance_data_dir}")

        model_id = "CompVis/stable-diffusion-v1-4"
        use_half = precision == "fp16"
        torch_dtype = torch.float16 if use_half else torch.float32
        revision = "fp16" if use_half else "main"
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            revision=revision,
            use_auth_token=hf_token,
        )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        pipe.to(device)

        print(f"🚧 Simulating training for {max_train_steps} steps at LR={learning_rate}")
        for step in range(max_train_steps):
            # Report every 100th step and always the final one.
            if step % 100 == 0 or step == max_train_steps - 1:
                print(f"Step {step + 1}/{max_train_steps}")

        os.makedirs(output_dir, exist_ok=True)
        pipe.save_pretrained(output_dir)

        ensure_repo_exists(dataset_repo_id, hf_token)
        return f"🎉 Training completed. Model saved to: {output_dir}"
    except Exception as e:
        # Top-level boundary for the UI: surface the error as text.
        return f"❌ Training failed: {str(e)}"


# Ensure Gradio app runs properly
if __name__ == "__main__":
    import gradio as gr
    from datetime import datetime

    def start_training(instance_token, class_token, zip_file, output_dir, max_steps, lr, hf_token, seed, precision):
        """Adapt the Gradio widget values to ``train_model`` and return its status."""
        if zip_file is None:
            # No upload: report it the same way train_model reports failures
            # instead of raising AttributeError on ``None.name``.
            return "❌ Training failed: No training ZIP file was uploaded."
        # The File component may hand over either an object exposing ``.name``
        # or a plain path string; accept both.
        zip_path = getattr(zip_file, "name", zip_file)
        return train_model(
            instance_token=instance_token,
            class_token=class_token,
            zip_path=zip_path,
            output_dir=output_dir,
            max_train_steps=max_steps,
            learning_rate=lr,
            hf_token=hf_token,
            seed=seed,
            precision=precision,
        )

    def create_ui():
        """Build the Gradio Blocks app with the single "Train Model" tab."""
        with gr.Blocks() as demo:
            with gr.Tab("Train Model"):
                instance_token = gr.Textbox(label="Instance Token")
                class_token = gr.Textbox(label="Class Token")
                zip_file = gr.File(label="Training ZIP File")
                output_dir = gr.Textbox(label="Output Directory", value="trained_model")
                max_steps = gr.Number(label="Max Training Steps", value=1200)
                lr = gr.Number(label="Learning Rate", value=5e-6)
                seed = gr.Number(label="Random Seed", value=42)
                precision = gr.Dropdown(label="Precision Mode", choices=["fp16", "fp32"], value="fp16")
                hf_token_train = gr.Textbox(label="Hugging Face Token", type="password")
                train_btn = gr.Button("Start Training")
                train_output = gr.Textbox(label="Training Output", lines=8)

                # Input order must match start_training's positional parameters.
                train_btn.click(
                    fn=start_training,
                    inputs=[
                        instance_token,
                        class_token,
                        zip_file,
                        output_dir,
                        max_steps,
                        lr,
                        hf_token_train,
                        seed,
                        precision,
                    ],
                    outputs=train_output,
                )
        return demo

    demo = create_ui()
    demo.launch()