# Timtical's picture
# Update app.py
# fbad7a3 verified
import os
import shutil
import torch
import gc
from pathlib import Path
from diffusers import StableDiffusionPipeline
from accelerate.utils import set_seed
try:
from huggingface_hub import HfApi, create_repo
except ImportError as e:
raise ImportError("huggingface_hub is missing or incompatible. Please ensure it's installed and up to date.") from e
# Minimum dependency versions; check_versions() compares the installed
# packages against these before train_model() does any work.
REQUIRED_HF_HUB_VERSION: str = "0.22.0"  # huggingface_hub (HfApi, create_repo)
REQUIRED_DIFFUSERS_VERSION: str = "0.25.0"  # diffusers (StableDiffusionPipeline)
REQUIRED_ACCELERATE_VERSION: str = "0.27.2"  # accelerate (set_seed)
def _version_tuple(version: str) -> tuple:
    """Return the leading numeric release segment of ``version`` as an int tuple.

    Suffixes such as ``.dev0``, ``rc1`` or ``+cu118`` are ignored, and trailing
    zero components are dropped so ``"0.25"`` and ``"0.25.0"`` compare equal.
    Tuples compare numerically, so ``"0.9.0"`` < ``"0.22.0"`` as intended
    (plain string comparison gets this wrong).

    Raises:
        ValueError: if ``version`` does not start with a number.
    """
    import re

    match = re.match(r"\d+(?:\.\d+)*", version.strip())
    if match is None:
        raise ValueError(f"unparseable version string {version!r}")
    parts = [int(piece) for piece in match.group(0).split(".")]
    while len(parts) > 1 and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def check_versions():
    """Print installed dependency versions and enforce the minimum requirements.

    Compares huggingface_hub, diffusers and accelerate (in that order) against
    the module-level ``REQUIRED_*`` constants.

    Raises:
        RuntimeError: if a package is not installed, its version cannot be
            parsed, or it is older than the required minimum.
    """
    import importlib.metadata as metadata

    requirements = {
        "huggingface_hub": REQUIRED_HF_HUB_VERSION,
        "diffusers": REQUIRED_DIFFUSERS_VERSION,
        "accelerate": REQUIRED_ACCELERATE_VERSION,
    }
    try:
        installed = {name: metadata.version(name) for name in requirements}
    except metadata.PackageNotFoundError as e:
        raise RuntimeError(f"❌ Version check failed: {e}") from e

    print(
        f"🔍 Versions: huggingface_hub={installed['huggingface_hub']}, "
        f"diffusers={installed['diffusers']}, accelerate={installed['accelerate']}"
    )
    for name, minimum in requirements.items():
        try:
            too_old = _version_tuple(installed[name]) < _version_tuple(minimum)
        except ValueError as e:
            raise RuntimeError(f"❌ Version check failed: {e}") from e
        if too_old:
            raise RuntimeError(f"❌ Version check failed: {name} must be >= {minimum}")
def ensure_repo_exists(repo_id: str, hf_token: str):
    """Create the public dataset repo ``repo_id`` on the Hub if it is missing.

    A single ``create_repo(..., exist_ok=True)`` call is used rather than a
    separate existence probe. The probe and the creation then agree on the
    repo type (dataset) and on how a bare name such as ``"generated-images"``
    is namespaced under the token owner. As a result, re-running does not hit
    a 409 "already exists" conflict.

    Args:
        repo_id: ``"name"`` or ``"owner/name"`` of the dataset repo.
        hf_token: Hugging Face token with write access. An empty value is sent
            as ``None`` so huggingface_hub uses its default credential lookup.

    Raises:
        Errors from huggingface_hub (e.g. authentication or permission
        failures) propagate unchanged.
    """
    repo_url = create_repo(
        repo_id=repo_id,
        token=hf_token or None,
        repo_type="dataset",
        private=False,
        exist_ok=True,
    )
    print(f"✅ Repo '{repo_url.repo_id}' is ready.")
def _extract_instance_data(zip_path: str) -> Path:
    """Unpack ``zip_path`` into a freshly emptied ``instance_data`` directory.

    Returns:
        The directory the archive was extracted into.
    """
    instance_data_dir = Path("instance_data")
    # Wipe leftovers from a previous run so old files don't mix with new ones.
    if instance_data_dir.exists():
        shutil.rmtree(instance_data_dir)
    os.makedirs(instance_data_dir, exist_ok=True)
    # shutil infers the archive format from the file extension.
    shutil.unpack_archive(zip_path, instance_data_dir)
    print(f"✅ Data extracted to: {instance_data_dir}")
    return instance_data_dir


def train_model(
    instance_token: str,
    class_token: str,
    zip_path: str,
    output_dir: str,
    max_train_steps: int,
    learning_rate: float,
    hf_token: str,
    seed: int = 42,
    precision: str = "fp16",
    dataset_repo_id: str = "generated-images"
):
    """Run the (simulated) training job and save the base pipeline.

    NOTE: no fine-tuning happens yet. ``instance_token`` and ``class_token``
    are accepted but unused, and the step loop only prints progress. What is
    saved to ``output_dir`` is the unmodified base pipeline.

    Args:
        instance_token: Currently unused.
        class_token: Currently unused.
        zip_path: Archive with the training data; extracted into ``./instance_data``.
        output_dir: Directory the pipeline is written to via ``save_pretrained``.
        max_train_steps: Number of simulated steps (cast to ``int``).
        learning_rate: Only reported in the log.
        hf_token: Token used for the model download and the dataset repo. An
            empty value is treated as ``None`` (default credentials).
        seed: Passed to ``accelerate.utils.set_seed``.
        precision: ``"fp16"`` loads float16 weights from the ``fp16`` revision.
            ``"fp32"`` loads float32 weights from ``main``.
        dataset_repo_id: Dataset repo ensured to exist after saving (nothing is
            uploaded to it here).

    Returns:
        A status string: ``"🎉 Training completed. ..."`` on success, or
        ``"❌ Training failed: ..."``. Errors are reported, never raised, so the
        Gradio UI can show them.
    """
    # Reject bad input before any GPU, disk or network work.
    if precision not in ("fp16", "fp32"):
        return f"❌ Training failed: Invalid precision mode '{precision}'. Choose 'fp16' or 'fp32'."
    # An empty Gradio textbox yields ""; None avoids sending an empty credential.
    hf_token = hf_token or None
    try:
        check_versions()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        max_train_steps = int(max_train_steps)
        seed = int(seed)
        set_seed(seed)
        print(f"🔧 Using random seed: {seed}")

        _extract_instance_data(zip_path)

        model_id = "CompVis/stable-diffusion-v1-4"
        use_fp16 = precision == "fp16"
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if use_fp16 else torch.float32,
            revision="fp16" if use_fp16 else "main",
            # `use_auth_token` is deprecated in diffusers in favour of `token`.
            token=hf_token,
        )
        device = "cuda" if torch.cuda.is_available() else "cpu"
        pipe.to(device)

        print(f"🚧 Simulating training for {max_train_steps} steps at LR={learning_rate}")
        for step in range(max_train_steps):
            if step % 100 == 0 or step == max_train_steps - 1:
                print(f"Step {step + 1}/{max_train_steps}")

        os.makedirs(output_dir, exist_ok=True)
        pipe.save_pretrained(output_dir)
        ensure_repo_exists(dataset_repo_id, hf_token)
        return f"🎉 Training completed. Model saved to: {output_dir}"
    except Exception as e:
        # Top-level UI boundary: surface any failure as a message.
        return f"❌ Training failed: {e}"
# Build and launch the Gradio UI only when this file is run as a script.
if __name__ == "__main__":
    import gradio as gr
    from datetime import datetime

    def start_training(instance_token, class_token, zip_file, output_dir, max_steps, lr, hf_token, seed, precision):
        """Gradio click handler that forwards the form values to ``train_model``.

        ``zip_file`` is the ``gr.File`` value. It is ``None`` when nothing was
        uploaded. Otherwise it is either an object with a ``.name`` path or the
        path string itself — which one may depend on the Gradio version, so
        both are accepted.
        """
        # Without this guard an empty upload crashed with AttributeError on None.name.
        if zip_file is None:
            return "❌ Training failed: Please upload a training ZIP file."
        zip_path = getattr(zip_file, "name", zip_file)
        return train_model(
            instance_token=instance_token,
            class_token=class_token,
            zip_path=zip_path,
            output_dir=output_dir,
            max_train_steps=max_steps,
            learning_rate=lr,
            hf_token=hf_token,
            seed=seed,
            precision=precision,
        )

    def create_ui():
        """Assemble the "Train Model" tab and wire its button to ``start_training``."""
        with gr.Blocks() as demo:
            with gr.Tab("Train Model"):
                instance_token = gr.Textbox(label="Instance Token")
                class_token = gr.Textbox(label="Class Token")
                zip_file = gr.File(label="Training ZIP File")
                output_dir = gr.Textbox(label="Output Directory", value="trained_model")
                max_steps = gr.Number(label="Max Training Steps", value=1200)
                lr = gr.Number(label="Learning Rate", value=5e-6)
                seed = gr.Number(label="Random Seed", value=42)
                precision = gr.Dropdown(label="Precision Mode", choices=["fp16", "fp32"], value="fp16")
                hf_token_train = gr.Textbox(label="Hugging Face Token", type="password")
                train_btn = gr.Button("Start Training")
                train_output = gr.Textbox(label="Training Output", lines=8)
                train_btn.click(
                    fn=start_training,
                    inputs=[instance_token, class_token, zip_file, output_dir, max_steps, lr, hf_token_train, seed, precision],
                    outputs=train_output,
                )
        return demo

    demo = create_ui()
    demo.launch()