File size: 3,876 Bytes
7a1f255
 
 
 
 
 
 
7aa015e
 
 
 
 
 
0003569
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24767bd
 
 
 
 
 
3116e1f
 
3666c1e
 
 
 
7a1f255
 
 
 
 
 
 
 
5807b1c
 
a56a388
 
7a1f255
 
0003569
 
7a1f255
b02edaf
 
5807b1c
fb097b4
 
5807b1c
7aa015e
 
 
 
7a1f255
 
 
 
 
 
b6c0c92
7a1f255
 
 
5807b1c
 
 
7a1f255
 
5807b1c
 
7a1f255
 
 
b02edaf
 
 
 
fb097b4
 
b02edaf
7a1f255
 
 
24767bd
a56a388
24767bd
b6c0c92
7a1f255
 
b6c0c92
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import shutil
import torch
import gc
from pathlib import Path
from diffusers import StableDiffusionPipeline
from accelerate.utils import set_seed

try:
    from huggingface_hub import HfApi, create_repo
except ImportError as e:
    raise ImportError("huggingface_hub is missing or incompatible. Please ensure it's installed and up to date.") from e

# Minimum library versions that check_versions() enforces before training.
# Bump these together with the project's dependency pins.
REQUIRED_HF_HUB_VERSION: str = "0.22.0"
REQUIRED_DIFFUSERS_VERSION: str = "0.25.0"
REQUIRED_ACCELERATE_VERSION: str = "0.27.2"

def check_versions():
    import importlib.metadata as metadata
    try:
        hf_hub_version = metadata.version("huggingface_hub")
        diffusers_version = metadata.version("diffusers")
        accelerate_version = metadata.version("accelerate")

        print(f"πŸ” Versions: huggingface_hub={hf_hub_version}, diffusers={diffusers_version}, accelerate={accelerate_version}")

        if hf_hub_version < REQUIRED_HF_HUB_VERSION:
            raise RuntimeError(f"huggingface_hub must be >= {REQUIRED_HF_HUB_VERSION}")
        if diffusers_version < REQUIRED_DIFFUSERS_VERSION:
            raise RuntimeError(f"diffusers must be >= {REQUIRED_DIFFUSERS_VERSION}")
        if accelerate_version < REQUIRED_ACCELERATE_VERSION:
            raise RuntimeError(f"accelerate must be >= {REQUIRED_ACCELERATE_VERSION}")
    except Exception as e:
        raise RuntimeError(f"❌ Version check failed: {e}")

def ensure_repo_exists(repo_id: str, hf_token: str):
    """Make sure the Hugging Face *dataset* repo ``repo_id`` exists, creating it if needed.

    Args:
        repo_id: Repository id, either bare (``"name"``) or namespaced
            (``"user/name"``).
        hf_token: Hugging Face access token used for both the lookup and the
            creation.

    Raises:
        Exception: Any lookup error whose message does not contain "404" is
            re-raised unchanged. Errors from ``create_repo`` also propagate.
    """
    api = HfApi()
    try:
        # The lookup must use the same repo_type as create_repo below. The
        # default ("model") would never find an existing dataset repo.
        api.repo_info(repo_id, repo_type="dataset", token=hf_token)
    except Exception as e:
        # Only a "not found" response means the repo should be created.
        if "404" not in str(e):
            raise
        # exist_ok avoids a conflict error if the repo appears in the meantime.
        create_repo(
            repo_id=repo_id,
            token=hf_token,
            repo_type="dataset",
            private=False,
            exist_ok=True,
        )
        print(f"✅ Repo '{repo_id}' created.")
    else:
        print(f"ℹ️ Repo '{repo_id}' already exists.")

def train_model(
    instance_token: str,
    class_token: str,
    zip_path: str,
    output_dir: str,
    max_train_steps: int,
    learning_rate: float,
    hf_token: str,
    seed: int = 42,
    precision: str = "fp16",
    dataset_repo_id: str = "generated-images"
):
    try:
        check_versions()

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        max_train_steps = int(max_train_steps)
        seed = int(seed)
        set_seed(seed)
        print(f"πŸ”§ Using random seed: {seed}")

        if precision not in ["fp16", "fp32"]:
            return f"❌ Training failed: Invalid precision mode '{precision}'. Choose 'fp16' or 'fp32'."

        instance_data_dir = Path("instance_data")
        if instance_data_dir.exists():
            shutil.rmtree(instance_data_dir)
        os.makedirs(instance_data_dir, exist_ok=True)
        shutil.unpack_archive(zip_path, instance_data_dir)

        print(f"βœ… Data extracted to: {instance_data_dir}")

        model_id = "CompVis/stable-diffusion-v1-4"
        torch_dtype = torch.float16 if precision == "fp16" else torch.float32
        revision = "fp16" if precision == "fp16" else "main"

        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            revision=revision,
            use_auth_token=hf_token
        )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        pipe.to(device)

        print(f"🚧 Simulating training for {max_train_steps} steps at LR={learning_rate}")
        for step in range(max_train_steps):
            if step % 100 == 0 or step == max_train_steps - 1:
                print(f"Step {step + 1}/{max_train_steps}")

        os.makedirs(output_dir, exist_ok=True)
        pipe.save_pretrained(output_dir)

        ensure_repo_exists(dataset_repo_id, hf_token)

        return f"πŸŽ‰ Training completed. Model saved to: {output_dir}"

    except Exception as e:
        return f"❌ Training failed: {str(e)}"