import os
import shutil
import torch
import gc
from pathlib import Path
from diffusers import StableDiffusionPipeline
from accelerate.utils import set_seed
# huggingface_hub powers the repo lookup/creation below; fail fast with a
# clear message rather than a bare ImportError deep in a call stack.
try:
    from huggingface_hub import HfApi, create_repo
except ImportError as e:
    raise ImportError("huggingface_hub is missing or incompatible. Please ensure it's installed and up to date.") from e
# Optional: Version compatibility check
# Minimum package versions enforced by check_versions() at the start of training.
REQUIRED_HF_HUB_VERSION = "0.22.0"
REQUIRED_DIFFUSERS_VERSION = "0.25.0"
REQUIRED_ACCELERATE_VERSION = "0.27.2"
def check_versions():
    """Verify installed huggingface_hub / diffusers / accelerate meet the minimums.

    Raises:
        RuntimeError: if any package is missing or older than its required
            version (the original cause is chained for debugging).
    """
    import importlib.metadata as metadata

    def _as_tuple(version: str) -> tuple:
        # Compare versions numerically: as plain strings "0.9.0" > "0.22.0",
        # which made the original checks wrong. Non-numeric suffixes
        # (e.g. "1.0.0rc1") are truncated to their leading digits.
        parts = []
        for piece in version.split("."):
            digits = ""
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            parts.append(int(digits) if digits else 0)
        return tuple(parts)

    try:
        hf_hub_version = metadata.version("huggingface_hub")
        diffusers_version = metadata.version("diffusers")
        accelerate_version = metadata.version("accelerate")
        print(f"🔍 Versions: huggingface_hub={hf_hub_version}, diffusers={diffusers_version}, accelerate={accelerate_version}")
        if _as_tuple(hf_hub_version) < _as_tuple(REQUIRED_HF_HUB_VERSION):
            raise RuntimeError(f"huggingface_hub must be >= {REQUIRED_HF_HUB_VERSION}")
        if _as_tuple(diffusers_version) < _as_tuple(REQUIRED_DIFFUSERS_VERSION):
            raise RuntimeError(f"diffusers must be >= {REQUIRED_DIFFUSERS_VERSION}")
        if _as_tuple(accelerate_version) < _as_tuple(REQUIRED_ACCELERATE_VERSION):
            raise RuntimeError(f"accelerate must be >= {REQUIRED_ACCELERATE_VERSION}")
    except Exception as e:
        # Chain the original cause so the real failure is visible to callers.
        raise RuntimeError(f"❌ Version check failed: {e}") from e
def ensure_repo_exists(repo_id: str, hf_token: str):
    """Create the Hugging Face *dataset* repo `repo_id` if it does not exist.

    Args:
        repo_id: Target repository id (e.g. "user/generated-images").
        hf_token: Hugging Face access token used for both lookup and creation.

    Raises:
        Any non-404 error from the Hub API is re-raised unchanged.
    """
    api = HfApi()
    try:
        # Look up the repo as a dataset — the same type create_repo makes below.
        # (The original queried the default "model" type, so an existing
        # dataset repo was never found and creation was re-attempted.)
        api.repo_info(repo_id, repo_type="dataset", token=hf_token)
        print(f"ℹ️ Repo '{repo_id}' already exists.")
    except Exception as e:
        # Dependency-light 404 detection; huggingface_hub's
        # RepositoryNotFoundError would be the precise check — TODO confirm
        # the installed version exposes it before tightening this.
        if "404" in str(e):
            create_repo(repo_id=repo_id, token=hf_token, repo_type="dataset", private=False)
            print(f"✅ Repo '{repo_id}' created.")
        else:
            raise
def train_model(
    instance_token: str,
    class_token: str,
    zip_path: str,
    output_dir: str,
    max_train_steps: int,
    learning_rate: float,
    hf_token: str,
    seed: int = 42,
    precision: str = "fp16",
    dataset_repo_id: str = "generated-images"
):
    """Simulate a fine-tuning run and save the base pipeline to `output_dir`.

    NOTE(review): no optimization actually happens — the "training" loop only
    prints progress. `instance_token`, `class_token` and `learning_rate` are
    currently unused beyond logging; confirm whether real training is intended.

    Args:
        instance_token: Token identifying the subject (currently unused).
        class_token: Token identifying the class prior (currently unused).
        zip_path: Archive of instance images, extracted to ./instance_data.
        output_dir: Directory the pipeline is saved to (created if missing).
        max_train_steps: Number of simulated steps (coerced to int).
        learning_rate: Logged only.
        hf_token: Hugging Face token for model download and repo creation.
        seed: Random seed passed to accelerate's set_seed (coerced to int).
        precision: "fp16" or "fp32"; selects dtype and model revision.
        dataset_repo_id: Dataset repo ensured to exist after saving.

    Returns:
        A status string: success message on completion, or an
        "❌ Training failed: ..." message on any error (never raises).
    """
    try:
        check_versions()
        # Free as much memory as possible before loading the pipeline.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        max_train_steps = int(max_train_steps)
        seed = int(seed)
        set_seed(seed)
        print(f"🔧 Using random seed: {seed}")

        if precision not in ["fp16", "fp32"]:
            return f"❌ Training failed: Invalid precision mode '{precision}'. Choose 'fp16' or 'fp32'."

        # Start from a clean instance-data directory on every run.
        instance_data_dir = Path("instance_data")
        if instance_data_dir.exists():
            shutil.rmtree(instance_data_dir)
        instance_data_dir.mkdir(parents=True, exist_ok=True)
        shutil.unpack_archive(zip_path, instance_data_dir)
        print(f"✅ Data extracted to: {instance_data_dir}")

        model_id = "CompVis/stable-diffusion-v1-4"
        torch_dtype = torch.float16 if precision == "fp16" else torch.float32
        # The v1-4 repo publishes an "fp16" revision with half-precision weights.
        revision = "fp16" if precision == "fp16" else "main"
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            revision=revision,
            token=hf_token,  # `use_auth_token` is deprecated in diffusers >= 0.25
        )
        device = "cuda" if torch.cuda.is_available() else "cpu"
        pipe.to(device)

        print(f"🧠 Simulating training for {max_train_steps} steps at LR={learning_rate}")
        for step in range(max_train_steps):
            # Log every 100 steps and always on the final step.
            if step % 100 == 0 or step == max_train_steps - 1:
                print(f"Step {step + 1}/{max_train_steps}")

        os.makedirs(output_dir, exist_ok=True)
        pipe.save_pretrained(output_dir)
        ensure_repo_exists(dataset_repo_id, hf_token)
        return f"🎉 Training completed. Model saved to: {output_dir}"
    except Exception as e:
        return f"❌ Training failed: {str(e)}"
|