File size: 5,690 Bytes
a7fce53
92e3d48
 
 
 
 
 
725c992
92e3d48
 
 
 
725c992
92e3d48
 
 
 
a53bf22
92e3d48
 
725c992
92e3d48
 
 
20ae52f
92e3d48
 
 
 
 
 
 
 
725c992
92e3d48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62defe9
92e3d48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0adab8d
 
92e3d48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8926422
92e3d48
 
 
81b6c1c
92e3d48
8926422
 
 
 
 
 
fbad7a3
 
 
 
 
 
 
 
 
 
 
 
 
8926422
f810ee4
fbad7a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f810ee4
8926422
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import shutil
import torch
import gc
from pathlib import Path
from diffusers import StableDiffusionPipeline
from accelerate.utils import set_seed

# Guard the hub import so a missing/broken install fails at startup with an
# actionable message instead of a bare ImportError deep inside a callback.
try:
    from huggingface_hub import HfApi, create_repo
except ImportError as e:
    raise ImportError("huggingface_hub is missing or incompatible. Please ensure it's installed and up to date.") from e

# Optional: Version compatibility check
# Minimum library versions this script expects; enforced by check_versions().
REQUIRED_HF_HUB_VERSION = "0.22.0"
REQUIRED_DIFFUSERS_VERSION = "0.25.0"
REQUIRED_ACCELERATE_VERSION = "0.27.2"

def check_versions():
    import importlib.metadata as metadata
    try:
        hf_hub_version = metadata.version("huggingface_hub")
        diffusers_version = metadata.version("diffusers")
        accelerate_version = metadata.version("accelerate")

        print(f"πŸ” Versions: huggingface_hub={hf_hub_version}, diffusers={diffusers_version}, accelerate={accelerate_version}")

        if hf_hub_version < REQUIRED_HF_HUB_VERSION:
            raise RuntimeError(f"huggingface_hub must be >= {REQUIRED_HF_HUB_VERSION}")
        if diffusers_version < REQUIRED_DIFFUSERS_VERSION:
            raise RuntimeError(f"diffusers must be >= {REQUIRED_DIFFUSERS_VERSION}")
        if accelerate_version < REQUIRED_ACCELERATE_VERSION:
            raise RuntimeError(f"accelerate must be >= {REQUIRED_ACCELERATE_VERSION}")
    except Exception as e:
        raise RuntimeError(f"❌ Version check failed: {e}")

def ensure_repo_exists(repo_id: str, hf_token: str):
    """Create the Hugging Face *dataset* repo `repo_id` if it does not already exist.

    Args:
        repo_id: Target repository id.
        hf_token: Access token used for both the lookup and the creation.

    Raises:
        Exception: re-raises any hub error that is not a 404/not-found.
    """
    api = HfApi()
    try:
        # BUG FIX: query the same repo type we would create. Without
        # repo_type, repo_info looks up a *model* repo, so an existing
        # dataset repo 404'd here and triggered a redundant create_repo.
        api.repo_info(repo_id, repo_type="dataset", token=hf_token)
        print(f"ℹ️ Repo '{repo_id}' already exists.")
    except Exception as e:
        # Matching "404" in the message is brittle but avoids depending on
        # hub-specific exception classes.
        # NOTE(review): prefer huggingface_hub.utils.RepositoryNotFoundError
        # once the minimum hub version guarantees it — confirm availability.
        if "404" in str(e):
            create_repo(repo_id=repo_id, token=hf_token, repo_type="dataset", private=False)
            print(f"βœ… Repo '{repo_id}' created.")
        else:
            raise

def train_model(
    instance_token: str,
    class_token: str,
    zip_path: str,
    output_dir: str,
    max_train_steps: int,
    learning_rate: float,
    hf_token: str,
    seed: int = 42,
    precision: str = "fp16",
    dataset_repo_id: str = "generated-images"
) -> str:
    """Run a *simulated* DreamBooth-style fine-tune of Stable Diffusion v1.4.

    The "training" loop below only prints progress — no optimizer step is
    taken, so the pretrained pipeline is saved to `output_dir` unchanged.
    `instance_token` and `class_token` are currently unused by this body, and
    `learning_rate` appears only in the progress log.

    Args:
        instance_token: Token naming the instance concept (unused here).
        class_token: Token naming the class concept (unused here).
        zip_path: Path to a ZIP archive of training images; extracted into
            a fresh local `instance_data/` directory.
        output_dir: Directory where the pipeline is saved (created if needed).
        max_train_steps: Number of simulated steps; coerced to int.
        learning_rate: Logged only; no optimizer exists in this simulation.
        hf_token: Hugging Face token for model download and repo creation.
        seed: Random seed passed to accelerate's set_seed; coerced to int.
        precision: "fp16" or "fp32"; selects torch dtype and model revision.
        dataset_repo_id: Dataset repo ensured to exist after "training".

    Returns:
        A human-readable status string — success, or "❌ Training failed: ..."
        on any error — so the result can be shown directly in the Gradio UI
        instead of raising.
    """
    try:
        check_versions()

        # Free any leftover Python/CUDA memory before loading the pipeline.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Gradio Number inputs may arrive as floats — normalize to ints.
        max_train_steps = int(max_train_steps)
        seed = int(seed)
        set_seed(seed)
        print(f"πŸ”§ Using random seed: {seed}")

        if precision not in ["fp16", "fp32"]:
            return f"❌ Training failed: Invalid precision mode '{precision}'. Choose 'fp16' or 'fp32'."

        # Recreate instance_data/ from scratch so stale images from a prior
        # run cannot leak into this one.
        instance_data_dir = Path("instance_data")
        if instance_data_dir.exists():
            shutil.rmtree(instance_data_dir)
        os.makedirs(instance_data_dir, exist_ok=True)
        shutil.unpack_archive(zip_path, instance_data_dir)

        print(f"βœ… Data extracted to: {instance_data_dir}")

        model_id = "CompVis/stable-diffusion-v1-4"
        torch_dtype = torch.float16 if precision == "fp16" else torch.float32
        # The sd-v1-4 repo ships a dedicated "fp16" weights revision.
        revision = "fp16" if precision == "fp16" else "main"

        # NOTE(review): use_auth_token is deprecated in newer diffusers in
        # favor of token= — presumably still accepted at the pinned minimum
        # version; confirm before bumping requirements.
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            revision=revision,
            use_auth_token=hf_token
        )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        pipe.to(device)

        # Simulation only: log progress every 100 steps and at the last step.
        print(f"🚧 Simulating training for {max_train_steps} steps at LR={learning_rate}")
        for step in range(max_train_steps):
            if step % 100 == 0 or step == max_train_steps - 1:
                print(f"Step {step + 1}/{max_train_steps}")

        os.makedirs(output_dir, exist_ok=True)
        pipe.save_pretrained(output_dir)

        # Make sure the destination dataset repo exists for later uploads.
        ensure_repo_exists(dataset_repo_id, hf_token)

        return f"πŸŽ‰ Training completed. Model saved to: {output_dir}"

    except Exception as e:
        # Deliberate catch-all: the UI expects a status string, not a traceback.
        return f"❌ Training failed: {str(e)}"

# Ensure Gradio app runs properly
if __name__ == "__main__":
    import gradio as gr
    from datetime import datetime

    def start_training(instance_token, class_token, zip_file, output_dir, max_steps, lr, hf_token, seed, precision):
        """Adapt raw Gradio widget values to train_model's keyword interface."""
        # NOTE(review): zip_file.name assumes gr.File yields a file-like/tempfile
        # object; newer Gradio versions can return a plain filepath string —
        # confirm against the installed Gradio version.
        return train_model(
            instance_token=instance_token,
            class_token=class_token,
            zip_path=zip_file.name,
            output_dir=output_dir,
            max_train_steps=max_steps,
            learning_rate=lr,
            hf_token=hf_token,
            seed=seed,
            precision=precision
        )

    def create_ui():
        """Build the single-tab Gradio Blocks UI and return it (not launched)."""
        with gr.Blocks() as demo:
            with gr.Tab("Train Model"):
                instance_token = gr.Textbox(label="Instance Token")
                class_token = gr.Textbox(label="Class Token")
                zip_file = gr.File(label="Training ZIP File")
                output_dir = gr.Textbox(label="Output Directory", value="trained_model")
                max_steps = gr.Number(label="Max Training Steps", value=1200)
                lr = gr.Number(label="Learning Rate", value=5e-6)
                seed = gr.Number(label="Random Seed", value=42)
                precision = gr.Dropdown(label="Precision Mode", choices=["fp16", "fp32"], value="fp16")
                # Password-type field so the token is masked in the browser.
                hf_token_train = gr.Textbox(label="Hugging Face Token", type="password")
                train_btn = gr.Button("Start Training")
                train_output = gr.Textbox(label="Training Output", lines=8)

                # Input order here must match start_training's parameter order.
                train_btn.click(
                    fn=start_training,
                    inputs=[instance_token, class_token, zip_file, output_dir, max_steps, lr, hf_token_train, seed, precision],
                    outputs=train_output
                )
        return demo

    demo = create_ui()
    demo.launch()