File size: 9,792 Bytes
553a5e7
 
32c80ed
 
 
 
 
553a5e7
 
 
 
 
 
0bde4c8
 
553a5e7
 
32c80ed
0bde4c8
 
 
 
 
 
 
 
 
 
 
 
 
553a5e7
11c83cf
553a5e7
 
 
 
 
 
 
 
 
0bde4c8
32c80ed
553a5e7
 
 
 
32c80ed
 
553a5e7
 
32c80ed
 
 
 
553a5e7
0bde4c8
 
553a5e7
 
11c83cf
553a5e7
 
11c83cf
 
 
 
0bde4c8
11c83cf
0bde4c8
11c83cf
 
 
 
0bde4c8
9d443a8
 
 
 
0bde4c8
32c80ed
553a5e7
11c83cf
553a5e7
 
0bde4c8
553a5e7
32c80ed
553a5e7
 
32c80ed
 
553a5e7
 
 
32c80ed
0bde4c8
 
 
553a5e7
 
 
 
32c80ed
 
 
 
 
 
553a5e7
32c80ed
0bde4c8
553a5e7
32c80ed
553a5e7
 
32c80ed
 
0bde4c8
 
32c80ed
553a5e7
 
 
0bde4c8
 
553a5e7
 
11c83cf
553a5e7
 
32c80ed
 
553a5e7
32c80ed
 
 
11c83cf
32c80ed
0bde4c8
 
32c80ed
553a5e7
 
 
 
 
0bde4c8
553a5e7
0bde4c8
553a5e7
 
 
0bde4c8
11c83cf
 
9f30647
 
0bde4c8
11c83cf
0bde4c8
553a5e7
0bde4c8
32c80ed
0bde4c8
11c83cf
0bde4c8
 
32c80ed
 
 
 
0bde4c8
9f30647
 
 
11c83cf
0bde4c8
 
 
 
 
 
9f30647
32c80ed
 
 
 
0bde4c8
 
 
 
 
 
32c80ed
 
0bde4c8
 
 
32c80ed
0bde4c8
 
 
11c83cf
0bde4c8
11c83cf
553a5e7
 
 
11c83cf
 
 
553a5e7
 
32c80ed
11c83cf
32c80ed
 
11c83cf
32c80ed
 
11c83cf
 
 
32c80ed
553a5e7
32c80ed
11c83cf
32c80ed
 
 
 
 
 
 
11c83cf
 
 
32c80ed
 
 
11c83cf
553a5e7
32c80ed
11c83cf
553a5e7
11c83cf
 
32c80ed
11c83cf
 
32c80ed
11c83cf
32c80ed
11c83cf
9d443a8
32c80ed
11c83cf
 
553a5e7
32c80ed
11c83cf
553a5e7
11c83cf
0bde4c8
 
11c83cf
 
553a5e7
 
0bde4c8
553a5e7
32c80ed
 
 
 
553a5e7
0bde4c8
553a5e7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import gradio as gr
import torch
import os
import threading
import queue
import time
import json
from transformers import (
    GPT2Config, 
    GPT2LMHeadModel, 
    GPT2Tokenizer, 
    Trainer, 
    TrainingArguments, 
    DataCollatorForLanguageModeling,
    TrainerCallback
)
from datasets import load_dataset
from huggingface_hub import whoami, HfApi

# --- Helper Classes ---

class LogQueueCallback(TrainerCallback):
    """Trainer callback that forwards each log event to a queue read by the UI."""

    def __init__(self, log_queue):
        # Queue shared with the foreground generator that streams logs to Gradio.
        self.log_queue = log_queue

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Skip empty log events; otherwise emit one formatted line per event.
        if not logs:
            return
        self.log_queue.put(f"Step {state.global_step}: {json.dumps(logs)}\n")

def get_username(token):
    """Resolve the Hugging Face username owning *token*.

    Returns None for a falsy token, or when the token cannot be validated
    against the Hub (bad credentials, network failure, etc.).
    """
    if not token:
        return None
    try:
        return whoami(token=token)["name"]
    except Exception:
        # Best-effort: any validation failure is reported as "no user".
        return None

def train_thread_target(
    token,
    dataset_id, 
    model_name, 
    num_layers, 
    n_embd, 
    n_head,
    context_length,
    epochs, 
    lr, 
    weight_decay,
    warmup_steps,
    batch_size,
    grad_accumulation,
    sample_limit,
    log_queue, 
    result_queue
):
    """
    Background thread for training and pushing to user profile.

    Builds a fresh GPT-2 (sized by num_layers / n_embd / n_head /
    context_length), trains it on the first `sample_limit` rows of the
    `train` split of `dataset_id`, then pushes model + tokenizer to
    `{username}/{model_name}` on the Hugging Face Hub.

    Communication contract with the UI generator:
      - progress strings are pushed to `log_queue` as they happen;
      - exactly one item is put on `result_queue` before the thread
        exits: a success-message string, or None on any failure
        (the error text itself goes to `log_queue`).
    """
    try:
        # 0. Auth & Identity
        # Explicit token wins; otherwise fall back to the HF_TOKEN env var
        # (e.g. a Space secret).
        final_token = token or os.environ.get("HF_TOKEN")
        username = get_username(final_token)
        
        if not username:
            raise ValueError("Invalid or missing Hugging Face Token. Ensure the token is provided or set as HF_TOKEN secret.")
            
        # Target path is now the USER'S profile
        full_repo_id = f"{username}/{model_name}"
        log_queue.put(f"πŸš€ Initializing for user: {username}\n")
        log_queue.put(f"πŸ“¦ Target Repository: https://huggingface.co/{full_repo_id}\n")

        # Validation for Transformer logic
        # Multi-head attention splits n_embd evenly across heads, so this
        # must divide cleanly or GPT2Config construction would fail later.
        if n_embd % n_head != 0:
            raise ValueError(f"Embedding dimension ({n_embd}) must be divisible by number of heads ({n_head}).")

        # 1. Load Dataset
        log_queue.put(f"πŸ“š Loading dataset: {dataset_id} (Limit: {sample_limit})...\n")
        try:
            # We use the train split; user can specify limit
            dataset = load_dataset(dataset_id, split=f"train[:{int(sample_limit)}]")
        except Exception as e:
            # Re-wrap so the outer handler reports a readable message.
            raise ValueError(f"Error loading dataset: {e}")

        # Auto-detect text column
        # Prefer a column literally named "text"; otherwise fall back to the
        # first column whose first row is a string.
        text_column = "text"
        if "text" not in dataset.column_names:
            for col in dataset.column_names:
                if isinstance(dataset[0][col], str):
                    text_column = col
                    break
        
        log_queue.put(f"πŸ” Using text column: '{text_column}'\n")

        # 2. Tokenize
        log_queue.put("βœ‚οΈ Tokenizing data...\n")
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        # GPT-2 has no pad token by default; reuse EOS so padding works.
        tokenizer.pad_token = tokenizer.eos_token

        def tokenize_function(examples):
            # Pad/truncate every example to the configured context length.
            return tokenizer(
                examples[text_column], 
                padding="max_length", 
                truncation=True, 
                max_length=int(context_length)
            )

        # Drop the raw columns so the Trainer only sees token tensors.
        tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)

        # 3. Initialize Model
        # Randomly initialized weights — this is pre-training from scratch,
        # not fine-tuning a published checkpoint.
        log_queue.put("πŸ—οΈ Building GPT-2 Architecture...\n")
        config = GPT2Config(
            vocab_size=len(tokenizer),
            n_positions=int(context_length),
            n_ctx=int(context_length),
            n_embd=int(n_embd),
            n_layer=int(num_layers),
            n_head=int(n_head),
        )
        model = GPT2LMHeadModel(config)

        # 4. Train
        log_queue.put("πŸ‹οΈ Starting Training Loop...\n")
        
        training_args = TrainingArguments(
            output_dir="./local_results",
            overwrite_output_dir=True,
            num_train_epochs=epochs,
            per_device_train_batch_size=int(batch_size),
            gradient_accumulation_steps=int(grad_accumulation),
            learning_rate=lr,
            weight_decay=weight_decay,
            warmup_steps=int(warmup_steps),
            logging_steps=10,
            # No checkpoints and no auto-push: the final weights are pushed
            # explicitly below, once, to the user's own namespace.
            save_strategy="no", 
            push_to_hub=False,
            report_to="none",
            # Use fp16 only when CUDA is available; otherwise run on CPU.
            use_cpu=not torch.cuda.is_available(),
            fp16=torch.cuda.is_available(),
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            # mlm=False => causal (next-token) language modeling objective.
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
            train_dataset=tokenized_datasets,
            # Streams per-step loss/metrics back to the UI via log_queue.
            callbacks=[LogQueueCallback(log_queue)]
        )

        trainer.train()

        # 5. Push to User's Personal Hub
        log_queue.put(f"☁️ Uploading model to your profile...\n")
        model.push_to_hub(full_repo_id, token=final_token)
        tokenizer.push_to_hub(full_repo_id, token=final_token)

        result_queue.put(f"πŸŽ‰ Success! Published to: https://huggingface.co/{full_repo_id}")
    
    except Exception as e:
        # Report the failure to the log stream and signal "no result" to
        # the waiting generator so the UI can show a failure status.
        log_queue.put(f"❌ Error: {str(e)}\n")
        result_queue.put(None)

# --- Generator for UI updates ---

def train_and_push_generator(
    token, dataset_id, model_name, 
    num_layers, n_embd, n_head, context_length,
    epochs, lr, weight_decay, warmup_steps,
    batch_size, grad_accumulation, sample_limit
):
    """Run training in a background thread and stream (logs, status) tuples to the UI.

    Yields immediately with an error message when no token is available;
    otherwise polls the worker's log queue twice a second until it exits,
    then emits one final tuple with the outcome.
    """
    effective_token = token or os.environ.get("HF_TOKEN")
    if not effective_token:
        yield "Error: No Hugging Face Token found. Please enter a 'Write' token below.", ""
        return

    log_queue, result_queue = queue.Queue(), queue.Queue()

    worker = threading.Thread(target=train_thread_target, args=(
        effective_token, dataset_id, model_name, 
        num_layers, n_embd, n_head, context_length,
        epochs, lr, weight_decay, warmup_steps,
        batch_size, grad_accumulation, sample_limit,
        log_queue, result_queue
    ))
    worker.start()

    accumulated = ""
    while worker.is_alive():
        # Emit one UI update per queued log line, then back off briefly.
        while not log_queue.empty():
            accumulated += log_queue.get()
            yield accumulated, "Training in progress..."
        time.sleep(0.5)

    # Drain anything logged between the last poll and thread exit.
    while not log_queue.empty():
        accumulated += log_queue.get()

    if result_queue.empty():
        yield accumulated, "Process interrupted."
    else:
        outcome = result_queue.get()
        yield accumulated, outcome if outcome else "Training failed. See logs."

# --- UI Layout ---

# Gradio UI: token/model-name row, three config tabs (data, architecture,
# training), a launch button, and a streaming log + final-status row.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="slate")) as demo:
    gr.Markdown("# πŸš€ Personal Auto-PreTrain")
    gr.Markdown("Configure a custom GPT-2 architecture and train it directly to **your personal** Hugging Face profile.")
    
    # Credentials + destination repo name (pushed to {username}/{model_name}).
    with gr.Row():
        hf_token = gr.Textbox(
            label="HF Write Token", 
            placeholder="hf_...", 
            type="password",
            info="Required to create the repo on your profile. Must have 'Write' permissions."
        )
        model_name_input = gr.Textbox(
            label="Model Name", 
            value="my-custom-gpt2",
            placeholder="e.g. tiny-stories-v1"
        )

    with gr.Tabs():
        # Tab 1: which dataset to train on and how much of it to use.
        with gr.TabItem("1. Data Selection"):
            with gr.Row():
                dataset_input = gr.Textbox(
                    label="Dataset ID", 
                    value="roneneldan/TinyStories",
                    placeholder="e.g. wikitext"
                )
                sample_limit = gr.Number(
                    label="Training Samples", 
                    value=500, 
                    precision=0
                )
            context_length = gr.Slider(
                minimum=64, maximum=1024, value=128, step=64, 
                label="Max Context Length"
            )

        # Tab 2: GPT-2 size knobs (fed into GPT2Config by the worker thread).
        with gr.TabItem("2. Architecture"):
            with gr.Row():
                layers = gr.Slider(minimum=1, maximum=12, value=2, step=1, label="Layers")
                embd = gr.Slider(minimum=64, maximum=1024, value=128, step=64, label="Embedding Dim")
            with gr.Row():
                heads = gr.Slider(minimum=2, maximum=16, value=4, step=2, label="Attention Heads")
                gr.Markdown("_Note: Embedding Dim must be divisible by Attention Heads._")

        # Tab 3: optimizer/schedule hyperparameters for TrainingArguments.
        with gr.TabItem("3. Training Settings"):
            with gr.Row():
                epochs = gr.Slider(minimum=1, maximum=20, value=1, step=1, label="Epochs")
                lr = gr.Number(label="Learning Rate", value=5e-4)
            with gr.Row():
                batch_size = gr.Slider(minimum=1, maximum=32, value=4, step=1, label="Batch Size")
                grad_accumulation = gr.Slider(minimum=1, maximum=16, value=1, step=1, label="Grad Accumulation")
            with gr.Row():
                weight_decay = gr.Slider(minimum=0.0, maximum=0.1, value=0.01, step=0.01, label="Weight Decay")
                warmup_steps = gr.Number(label="Warmup Steps", value=50, precision=0)

    train_btn = gr.Button("πŸ”₯ Start Training & Push to My Profile", variant="primary")
    
    # Streaming outputs: the generator yields (log text, status) tuples.
    with gr.Row():
        log_output = gr.Code(label="Training Progress", language="json", lines=12)
        status_output = gr.Textbox(label="Final Status", interactive=False)

    # Input order here must match train_and_push_generator's signature.
    train_btn.click(
        fn=train_and_push_generator,
        inputs=[
            hf_token, dataset_input, model_name_input, 
            layers, embd, heads, context_length,
            epochs, lr, weight_decay, warmup_steps,
            batch_size, grad_accumulation, sample_limit
        ],
        outputs=[log_output, status_output]
    )

if __name__ == "__main__":
    demo.launch()