import os
import shutil
import torch
import gc
from pathlib import Path
from diffusers import StableDiffusionPipeline
from accelerate.utils import set_seed
try:
from huggingface_hub import HfApi, create_repo
except ImportError as e:
raise ImportError("huggingface_hub is missing or incompatible. Please ensure it's installed and up to date.") from e
# Optional: Version compatibility check
REQUIRED_HF_HUB_VERSION = "0.22.0"
REQUIRED_DIFFUSERS_VERSION = "0.25.0"
REQUIRED_ACCELERATE_VERSION = "0.27.2"
def check_versions():
import importlib.metadata as metadata
try:
hf_hub_version = metadata.version("huggingface_hub")
diffusers_version = metadata.version("diffusers")
accelerate_version = metadata.version("accelerate")
print(f"π Versions: huggingface_hub={hf_hub_version}, diffusers={diffusers_version}, accelerate={accelerate_version}")
if hf_hub_version < REQUIRED_HF_HUB_VERSION:
raise RuntimeError(f"huggingface_hub must be >= {REQUIRED_HF_HUB_VERSION}")
if diffusers_version < REQUIRED_DIFFUSERS_VERSION:
raise RuntimeError(f"diffusers must be >= {REQUIRED_DIFFUSERS_VERSION}")
if accelerate_version < REQUIRED_ACCELERATE_VERSION:
raise RuntimeError(f"accelerate must be >= {REQUIRED_ACCELERATE_VERSION}")
except Exception as e:
raise RuntimeError(f"β Version check failed: {e}")
def ensure_repo_exists(repo_id: str, hf_token: str):
    """Create the Hugging Face *dataset* repo if it does not already exist.

    Args:
        repo_id: Target repository id (e.g. "user/generated-images").
        hf_token: Hugging Face access token with write permission.

    Raises:
        Exception: re-raises any Hub error that is not a simple 404.
    """
    api = HfApi()
    try:
        # Must query the same repo_type we create below: the default
        # repo_type ("model") would 404 for an existing *dataset* repo and
        # trigger a doomed create_repo call on every run.
        api.repo_info(repo_id, repo_type="dataset", token=hf_token)
        print(f"ℹ️ Repo '{repo_id}' already exists.")
    except Exception as e:
        if "404" in str(e):
            create_repo(repo_id=repo_id, token=hf_token, repo_type="dataset", private=False)
            print(f"✅ Repo '{repo_id}' created.")
        else:
            # Auth errors, network failures, etc. should surface to the caller.
            raise
def train_model(
    instance_token: str,
    class_token: str,
    zip_path: str,
    output_dir: str,
    max_train_steps: int,
    learning_rate: float,
    hf_token: str,
    seed: int = 42,
    precision: str = "fp16",
    dataset_repo_id: str = "generated-images"
):
    """Run a *simulated* DreamBooth-style fine-tuning pass and save the model.

    Extracts the uploaded ZIP of instance images, loads the base Stable
    Diffusion pipeline, loops over `max_train_steps` (logging only — no
    gradient updates are performed), saves the pipeline to `output_dir`, and
    ensures the target dataset repo exists on the Hub.

    Returns:
        A human-readable status string (success or failure). Errors are
        reported via the return value rather than raised, so the Gradio UI
        can display them directly.
    """
    try:
        check_versions()

        # Free CPU/GPU memory left over from any previous run.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Gradio Number widgets deliver floats; coerce to ints before use.
        max_train_steps = int(max_train_steps)
        seed = int(seed)
        set_seed(seed)
        print(f"🌱 Using random seed: {seed}")

        if precision not in ("fp16", "fp32"):
            return f"❌ Training failed: Invalid precision mode '{precision}'. Choose 'fp16' or 'fp32'."

        # Start from a clean instance-data directory every run.
        instance_data_dir = Path("instance_data")
        if instance_data_dir.exists():
            shutil.rmtree(instance_data_dir)
        os.makedirs(instance_data_dir, exist_ok=True)
        # SECURITY NOTE: shutil.unpack_archive does not guard against path
        # traversal ("zip slip") in hostile archives; only feed it trusted
        # uploads.
        shutil.unpack_archive(zip_path, instance_data_dir)
        print(f"✅ Data extracted to: {instance_data_dir}")

        model_id = "CompVis/stable-diffusion-v1-4"
        torch_dtype = torch.float16 if precision == "fp16" else torch.float32
        revision = "fp16" if precision == "fp16" else "main"
        # `token` replaces the deprecated `use_auth_token` argument
        # (deprecated since diffusers 0.19 and removed in later releases).
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            revision=revision,
            token=hf_token
        )
        device = "cuda" if torch.cuda.is_available() else "cpu"
        pipe.to(device)

        # NOTE: this loop only simulates training progress; instance_token,
        # class_token and learning_rate are currently unused beyond logging.
        print(f"🔧 Simulating training for {max_train_steps} steps at LR={learning_rate}")
        for step in range(max_train_steps):
            if step % 100 == 0 or step == max_train_steps - 1:
                print(f"Step {step + 1}/{max_train_steps}")

        os.makedirs(output_dir, exist_ok=True)
        pipe.save_pretrained(output_dir)
        ensure_repo_exists(dataset_repo_id, hf_token)
        return f"🎉 Training completed. Model saved to: {output_dir}"
    except Exception as e:
        # Surface the failure as a display string for the UI instead of
        # crashing the app.
        return f"❌ Training failed: {str(e)}"
# Ensure Gradio app runs properly
if __name__ == "__main__":
    import gradio as gr

    def start_training(instance_token, class_token, zip_file, output_dir, max_steps, lr, hf_token, seed, precision):
        """Adapt raw Gradio widget values into a train_model() call."""
        # gr.File yields None when nothing was uploaded; fail gracefully
        # instead of crashing on zip_file.name with an AttributeError.
        if zip_file is None:
            return "❌ Training failed: Please upload a ZIP file of training images."
        return train_model(
            instance_token=instance_token,
            class_token=class_token,
            zip_path=zip_file.name,
            output_dir=output_dir,
            max_train_steps=max_steps,
            learning_rate=lr,
            hf_token=hf_token,
            seed=seed,
            precision=precision
        )

    def create_ui():
        """Build the single-tab Gradio Blocks interface for training."""
        with gr.Blocks() as demo:
            with gr.Tab("Train Model"):
                instance_token = gr.Textbox(label="Instance Token")
                class_token = gr.Textbox(label="Class Token")
                zip_file = gr.File(label="Training ZIP File")
                output_dir = gr.Textbox(label="Output Directory", value="trained_model")
                max_steps = gr.Number(label="Max Training Steps", value=1200)
                lr = gr.Number(label="Learning Rate", value=5e-6)
                seed = gr.Number(label="Random Seed", value=42)
                precision = gr.Dropdown(label="Precision Mode", choices=["fp16", "fp32"], value="fp16")
                hf_token_train = gr.Textbox(label="Hugging Face Token", type="password")
                train_btn = gr.Button("Start Training")
                train_output = gr.Textbox(label="Training Output", lines=8)
                train_btn.click(
                    fn=start_training,
                    inputs=[instance_token, class_token, zip_file, output_dir, max_steps, lr, hf_token_train, seed, precision],
                    outputs=train_output
                )
        return demo

    demo = create_ui()
    demo.launch()