AxionLab-official commited on
Commit
fd17ee1
·
verified ·
1 Parent(s): 5d0e06f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +776 -0
app.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ╔══════════════════════════════════════════════════════════════╗
3
+ ║ 🧪 Fine-Tuning Studio — HuggingFace Space ║
4
+ ║ Suporta: CPU / CPU Upgrade / T4 / A10G / A100 ║
5
+ ║ Modos: LoRA, QLoRA, Full Fine-Tuning ║
6
+ ║ Pós: Chat embutido + Download dos pesos ║
7
+ ╚══════════════════════════════════════════════════════════════╝
8
+ """
9
+
10
+ import os, gc, json, math, shutil, threading, time, logging
11
+ from pathlib import Path
12
+ from typing import Optional
13
+
14
+ import gradio as gr
15
+ import torch
16
+ import pandas as pd
17
+ from datasets import load_dataset, Dataset
18
+ from transformers import (
19
+ AutoTokenizer,
20
+ AutoModelForCausalLM,
21
+ TrainingArguments,
22
+ Trainer,
23
+ DataCollatorForLanguageModeling,
24
+ BitsAndBytesConfig,
25
+ GenerationConfig,
26
+ TrainerCallback,
27
+ )
28
+
29
+ # ── PEFT é opcional; detectado em tempo de execução ───────────
30
+ try:
31
+ from peft import (
32
+ LoraConfig,
33
+ get_peft_model,
34
+ prepare_model_for_kbit_training,
35
+ PeftModel,
36
+ TaskType,
37
+ )
38
+ PEFT_AVAILABLE = True
39
+ except ImportError:
40
+ PEFT_AVAILABLE = False
41
+
42
+ logging.basicConfig(level=logging.INFO)
43
+ logger = logging.getLogger(__name__)
44
+
45
+ # ──────────────────────────────────────────────────────────────
46
+ # HARDWARE DETECTION
47
+ # ──────────────────────────────────────────────────────────────
48
+
49
+ def detect_hardware() -> dict:
50
+ info = {"device": "cpu", "vram_gb": 0, "gpu_name": "N/A", "bf16": False}
51
+ if torch.cuda.is_available():
52
+ info["device"] = "cuda"
53
+ info["vram_gb"] = round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1)
54
+ info["gpu_name"] = torch.cuda.get_device_name(0)
55
+ info["bf16"] = torch.cuda.is_bf16_supported()
56
+ return info
57
+
58
+ HW = detect_hardware()
59
+
60
+ def hw_banner() -> str:
61
+ if HW["device"] == "cuda":
62
+ tier = "🟢 GPU" if HW["vram_gb"] >= 16 else "🟡 GPU (pequena)"
63
+ return (
64
+ f"{tier} · {HW['gpu_name']} · {HW['vram_gb']} GB VRAM | "
65
+ f"BF16: {'✅' if HW['bf16'] else '❌'} | "
66
+ f"PEFT/LoRA: {'✅' if PEFT_AVAILABLE else '❌ (instale peft)'}"
67
+ )
68
+ return (
69
+ f"🔵 CPU | Threads: {torch.get_num_threads()} | "
70
+ f"PEFT/LoRA: {'✅' if PEFT_AVAILABLE else '❌'}"
71
+ )
72
+
73
+ # ──────────────────────────────────────────────────────────────
74
+ # MODELO CATALOG (modelo_id, max_vram_recomendado_gb)
75
+ # ──────────────────────────────────────────────────────────────
76
+
77
+ MODEL_CATALOG = {
78
+ # Tiny — roda até em CPU
79
+ "TinyLlama 1.1B": ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 4),
80
+ "SmolLM 1.7B": ("HuggingFaceTB/SmolLM2-1.7B-Instruct", 6),
81
+ # Médio — T4 OK com QLoRA
82
+ "Mistral 7B": ("mistralai/Mistral-7B-Instruct-v0.2", 14),
83
+ "Llama 3.1 8B": ("meta-llama/Meta-Llama-3.1-8B-Instruct", 16),
84
+ "Gemma 2 9B": ("google/gemma-2-9b-it", 18),
85
+ # Grande — A10G / A100
86
+ "Llama 3.1 70B": ("meta-llama/Meta-Llama-3.1-70B-Instruct", 80),
87
+ "Mixtral 8x7B": ("mistralai/Mixtral-8x7B-Instruct-v0.1", 48),
88
+ }
89
+
90
+ def available_models() -> list[str]:
91
+ """Filtra modelos que cabem no hardware atual."""
92
+ vram = HW["vram_gb"] if HW["device"] == "cuda" else 2
93
+ return [name for name, (_, req) in MODEL_CATALOG.items() if req <= max(vram * 1.2, 6)]
94
+
95
+ # ──────────────────────────────────────────────────────────────
96
+ # ESTADO GLOBAL DO TREINAMENTO
97
+ # ──────────────────────────────────────────────────────────────
98
+
99
+ class TrainingState:
100
+ def __init__(self):
101
+ self.reset()
102
+
103
+ def reset(self):
104
+ self.running = False
105
+ self.cancelled = False
106
+ self.logs: list = []
107
+ self.progress: int = 0
108
+ self.total_steps = 0
109
+ self.model = None
110
+ self.tokenizer = None
111
+ self.output_dir = Path("./trained_model")
112
+ self.error: Optional[str] = None
113
+
114
+ def log(self, msg: str):
115
+ ts = time.strftime("%H:%M:%S")
116
+ self.logs.append(f"[{ts}] {msg}")
117
+ logger.info(msg)
118
+
119
+ def log_box(self) -> str:
120
+ return "\n".join(self.logs[-60:]) # últimas 60 linhas
121
+
122
+ STATE = TrainingState()
123
+
124
+ # ──────────────────────────────────────────────────────────────
125
+ # CALLBACK PARA PROGRESSO EM TEMPO REAL
126
+ # ──────────────────────────────────────────────────────────────
127
+
128
+ class ProgressCallback(TrainerCallback):
129
+ def on_train_begin(self, args, state, control, **kwargs):
130
+ STATE.total_steps = state.max_steps
131
+ STATE.log(f"▶ Treinamento iniciado — {state.max_steps} steps")
132
+
133
+ def on_log(self, args, state, control, logs=None, **kwargs):
134
+ if logs:
135
+ loss = logs.get("loss", "—")
136
+ lr = logs.get("learning_rate", "—")
137
+ step = state.global_step
138
+ STATE.progress = step
139
+ STATE.log(f"Step {step}/{STATE.total_steps} loss={loss} lr={lr}")
140
+
141
+ def on_step_end(self, args, state, control, **kwargs):
142
+ if STATE.cancelled:
143
+ control.should_training_stop = True
144
+
145
+ def on_train_end(self, args, state, control, **kwargs):
146
+ STATE.log("✅ Treinamento concluído!")
147
+
148
+ # ──────────────────────────────────────────────────────────────
149
+ # DATASET HELPERS
150
+ # ──────────────────────────────────────────────────────────────
151
+
152
+ def load_user_dataset(source: str, hf_dataset: str, uploaded_file) -> Dataset:
153
+ """Carrega dataset de múltiplas fontes."""
154
+ if source == "HuggingFace Hub" and hf_dataset.strip():
155
+ ds = load_dataset(hf_dataset.strip(), split="train")
156
+ return ds
157
+
158
+ if source == "Upload CSV/JSONL" and uploaded_file is not None:
159
+ path = uploaded_file.name
160
+ if path.endswith(".csv"):
161
+ df = pd.read_csv(path)
162
+ else:
163
+ df = pd.read_json(path, lines=True)
164
+ return Dataset.from_pandas(df)
165
+
166
+ # Fallback: dataset de exemplo embutido
167
+ examples = [
168
+ {"text": "Instrução: Explique o que é machine learning.\nResposta: Machine learning é..."},
169
+ {"text": "Instrução: O que é uma rede neural?\nResposta: Uma rede neural é..."},
170
+ {"text": "Instrução: Como funciona o backpropagation?\nResposta: O backpropagation..."},
171
+ ]
172
+ return Dataset.from_list(examples)
173
+
174
+ def tokenize_dataset(dataset: Dataset, tokenizer, max_length: int) -> Dataset:
175
+ text_col = next(
176
+ (c for c in ["text", "prompt", "instruction", "content"] if c in dataset.column_names),
177
+ dataset.column_names[0],
178
+ )
179
+
180
+ def tokenize(examples):
181
+ return tokenizer(
182
+ examples[text_col],
183
+ truncation=True,
184
+ max_length=max_length,
185
+ padding="max_length",
186
+ )
187
+
188
+ return dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)
189
+
190
+ # ──────────────────────────────────────────────────────────────
191
+ # CORE: CARREGA MODELO
192
+ # ──────────────────────────────────────────────────────────────
193
+
194
+ def load_model_and_tokenizer(model_name: str, ft_mode: str):
195
+ model_id, _ = MODEL_CATALOG[model_name]
196
+ STATE.log(f"⬇ Carregando tokenizer: {model_id}")
197
+
198
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
199
+ if tokenizer.pad_token is None:
200
+ tokenizer.pad_token = tokenizer.eos_token
201
+
202
+ # Configuração de quantização
203
+ bnb_cfg = None
204
+ load_in_4bit = False
205
+ load_in_8bit = False
206
+
207
+ if HW["device"] == "cuda" and ft_mode == "QLoRA":
208
+ if not PEFT_AVAILABLE:
209
+ raise RuntimeError("Instale `peft` e `bitsandbytes` para QLoRA.")
210
+ STATE.log("🔧 Configurando 4-bit NF4 (QLoRA)...")
211
+ bnb_cfg = BitsAndBytesConfig(
212
+ load_in_4bit=True,
213
+ bnb_4bit_use_double_quant=True,
214
+ bnb_4bit_quant_type="nf4",
215
+ bnb_4bit_compute_dtype=torch.bfloat16 if HW["bf16"] else torch.float16,
216
+ )
217
+ load_in_4bit = True
218
+
219
+ STATE.log(f"⬇ Carregando modelo ({ft_mode})...")
220
+ model = AutoModelForCausalLM.from_pretrained(
221
+ model_id,
222
+ quantization_config=bnb_cfg,
223
+ device_map="auto" if HW["device"] == "cuda" else None,
224
+ torch_dtype=torch.float16 if (HW["device"] == "cuda" and not HW["bf16"]) else "auto",
225
+ trust_remote_code=True,
226
+ )
227
+
228
+ if ft_mode in ("LoRA", "QLoRA") and PEFT_AVAILABLE:
229
+ if load_in_4bit:
230
+ model = prepare_model_for_kbit_training(model)
231
+
232
+ lora_cfg = LoraConfig(
233
+ r=16,
234
+ lora_alpha=32,
235
+ lora_dropout=0.05,
236
+ bias="none",
237
+ task_type=TaskType.CAUSAL_LM,
238
+ target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
239
+ )
240
+ model = get_peft_model(model, lora_cfg)
241
+ model.print_trainable_parameters()
242
+ trainable, total, pct = model.get_nb_trainable_parameters()
243
+ STATE.log(f"📊 Parâmetros treináveis: {trainable:,} / {total:,} ({pct:.2f}%)")
244
+
245
+ elif ft_mode == "Full Fine-Tuning":
246
+ STATE.log("⚠ Full fine-tuning: todos os pesos serão atualizados.")
247
+ if HW["device"] != "cuda":
248
+ STATE.log("⚠ Full fine-tuning em CPU será MUITO lento.")
249
+
250
+ return model, tokenizer
251
+
252
+ # ──────────────────────────────────────────────────────────────
253
+ # CORE: TREINAMENTO
254
+ # ──────────────────────────────────────────────────────────────
255
+
256
+ def run_training(
257
+ model_name, ft_mode,
258
+ dataset_source, hf_dataset, uploaded_file,
259
+ epochs, batch_size, learning_rate, max_length,
260
+ warmup_steps, weight_decay, grad_accum,
261
+ ):
262
+ try:
263
+ STATE.reset()
264
+ STATE.running = True
265
+ STATE.output_dir = Path(f"./trained_{model_name.replace(' ', '_')}_{ft_mode}")
266
+ STATE.output_dir.mkdir(parents=True, exist_ok=True)
267
+
268
+ STATE.log(f"🖥 Hardware: {HW['gpu_name'] if HW['device']=='cuda' else 'CPU'}")
269
+ STATE.log(f"📦 Modelo: {model_name} | Modo: {ft_mode}")
270
+
271
+ # 1. Dataset
272
+ STATE.log("📂 Carregando dataset...")
273
+ raw_ds = load_user_dataset(dataset_source, hf_dataset, uploaded_file)
274
+ STATE.log(f"✅ Dataset: {len(raw_ds)} exemplos")
275
+
276
+ # 2. Modelo
277
+ model, tokenizer = load_model_and_tokenizer(model_name, ft_mode)
278
+ STATE.tokenizer = tokenizer
279
+
280
+ # 3. Tokenização
281
+ STATE.log("🔤 Tokenizando dataset...")
282
+ tokenized = tokenize_dataset(raw_ds, tokenizer, max_length)
283
+ tokenized = tokenized.train_test_split(test_size=0.05, seed=42)
284
+
285
+ # 4. TrainingArguments
286
+ use_fp16 = HW["device"] == "cuda" and not HW["bf16"]
287
+ use_bf16 = HW["device"] == "cuda" and HW["bf16"]
288
+
289
+ args = TrainingArguments(
290
+ output_dir=str(STATE.output_dir),
291
+ num_train_epochs=epochs,
292
+ per_device_train_batch_size=batch_size,
293
+ gradient_accumulation_steps=grad_accum,
294
+ learning_rate=learning_rate,
295
+ warmup_steps=warmup_steps,
296
+ weight_decay=weight_decay,
297
+ fp16=use_fp16,
298
+ bf16=use_bf16,
299
+ logging_steps=5,
300
+ save_steps=50,
301
+ save_total_limit=2,
302
+ eval_strategy="steps",
303
+ eval_steps=50,
304
+ load_best_model_at_end=True,
305
+ report_to="none",
306
+ dataloader_pin_memory=(HW["device"] == "cuda"),
307
+ )
308
+
309
+ collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
310
+
311
+ trainer = Trainer(
312
+ model=model,
313
+ args=args,
314
+ train_dataset=tokenized["train"],
315
+ eval_dataset=tokenized["test"],
316
+ data_collator=collator,
317
+ callbacks=[ProgressCallback()],
318
+ )
319
+
320
+ STATE.log("🚀 Iniciando treinamento...")
321
+ trainer.train()
322
+
323
+ if not STATE.cancelled:
324
+ STATE.log("💾 Salvando modelo...")
325
+ model.save_pretrained(str(STATE.output_dir))
326
+ tokenizer.save_pretrained(str(STATE.output_dir))
327
+
328
+ # Salva metadados
329
+ meta = {
330
+ "base_model": MODEL_CATALOG[model_name][0],
331
+ "ft_mode": ft_mode,
332
+ "epochs": epochs,
333
+ "learning_rate": learning_rate,
334
+ "dataset_source": dataset_source,
335
+ "hardware": HW,
336
+ }
337
+ (STATE.output_dir / "training_meta.json").write_text(json.dumps(meta, indent=2))
338
+
339
+ STATE.model = model
340
+ STATE.log(f"🎉 Modelo salvo em: {STATE.output_dir}")
341
+
342
+ except Exception as e:
343
+ STATE.error = str(e)
344
+ STATE.log(f"❌ Erro: {e}")
345
+ logger.exception(e)
346
+ finally:
347
+ STATE.running = False
348
+
349
+ # ──────────────────────────────────────────────────────────────
350
+ # CHAT COM MODELO TREINADO
351
+ # ──────────────────────────────────────────────────────────────
352
+
353
+ def chat_with_model(message: str, history: list, max_new_tokens: int, temperature: float):
354
+ if STATE.model is None or STATE.tokenizer is None:
355
+ return history + [[message, "⚠ Nenhum modelo treinado disponível. Complete o treinamento primeiro."]]
356
+
357
+ prompt = ""
358
+ for user_msg, bot_msg in history:
359
+ prompt += f"Usuário: {user_msg}\nAssistente: {bot_msg}\n"
360
+ prompt += f"Usuário: {message}\nAssistente:"
361
+
362
+ inputs = STATE.tokenizer(prompt, return_tensors="pt")
363
+ if HW["device"] == "cuda":
364
+ inputs = {k: v.cuda() for k, v in inputs.items()}
365
+
366
+ with torch.no_grad():
367
+ outputs = STATE.model.generate(
368
+ **inputs,
369
+ max_new_tokens=max_new_tokens,
370
+ temperature=temperature,
371
+ do_sample=temperature > 0,
372
+ pad_token_id=STATE.tokenizer.eos_token_id,
373
+ )
374
+
375
+ gen_tokens = outputs[0][inputs["input_ids"].shape[1]:]
376
+ response = STATE.tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
377
+ return history + [[message, response]]
378
+
379
+ # ──────────────────────────────────────────────────────────────
380
+ # ZIP & DOWNLOAD
381
+ # ──────────────────────────────────────────────────────────────
382
+
383
+ def create_download_zip() -> Optional[str]:
384
+ if not STATE.output_dir.exists():
385
+ return None
386
+ zip_path = Path("./model_export.zip")
387
+ shutil.make_archive("model_export", "zip", str(STATE.output_dir))
388
+ return str(zip_path)
389
+
390
+ # ──────────────────────────────────────────────────────────────
391
+ # GRADIO UI
392
+ # ──────────────────────────────────────────────────────────────
393
+
394
+ CSS = """
395
+ @import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=Syne:wght@400;600;800&display=swap');
396
+
397
+ * { box-sizing: border-box; }
398
+
399
+ body, .gradio-container {
400
+ background: #0a0a0f !important;
401
+ color: #e8e6f0 !important;
402
+ font-family: 'Syne', sans-serif !important;
403
+ }
404
+
405
+ .gr-panel, .gr-box, .gr-block { background: transparent !important; }
406
+
407
+ /* Header */
408
+ .studio-header {
409
+ text-align: center;
410
+ padding: 2rem 0 1rem;
411
+ border-bottom: 1px solid #2a2a3a;
412
+ margin-bottom: 1.5rem;
413
+ }
414
+ .studio-header h1 {
415
+ font-family: 'Syne', sans-serif;
416
+ font-weight: 800;
417
+ font-size: 2.2rem;
418
+ letter-spacing: -0.02em;
419
+ color: #fff;
420
+ margin: 0;
421
+ }
422
+ .studio-header h1 span { color: #7c6af7; }
423
+ .studio-header p {
424
+ font-family: 'Space Mono', monospace;
425
+ font-size: 0.75rem;
426
+ color: #6b6888;
427
+ margin-top: 0.4rem;
428
+ letter-spacing: 0.08em;
429
+ }
430
+
431
+ /* Hardware badge */
432
+ .hw-badge {
433
+ font-family: 'Space Mono', monospace;
434
+ font-size: 0.72rem;
435
+ background: #12121e;
436
+ border: 1px solid #2a2a3a;
437
+ border-radius: 6px;
438
+ padding: 0.5rem 1rem;
439
+ color: #8a88a8;
440
+ text-align: center;
441
+ margin-bottom: 1.2rem;
442
+ }
443
+
444
+ /* Tabs */
445
+ .tab-nav button {
446
+ font-family: 'Syne', sans-serif !important;
447
+ font-weight: 600 !important;
448
+ font-size: 0.85rem !important;
449
+ letter-spacing: 0.04em !important;
450
+ color: #6b6888 !important;
451
+ background: transparent !important;
452
+ border: none !important;
453
+ border-bottom: 2px solid transparent !important;
454
+ padding: 0.5rem 1.2rem !important;
455
+ }
456
+ .tab-nav button.selected {
457
+ color: #7c6af7 !important;
458
+ border-bottom-color: #7c6af7 !important;
459
+ }
460
+
461
+ /* Inputs */
462
+ .gr-input, .gr-dropdown select, textarea {
463
+ background: #12121e !important;
464
+ border: 1px solid #2a2a3a !important;
465
+ color: #e8e6f0 !important;
466
+ border-radius: 8px !important;
467
+ font-family: 'Space Mono', monospace !important;
468
+ font-size: 0.8rem !important;
469
+ }
470
+ .gr-input:focus, textarea:focus {
471
+ border-color: #7c6af7 !important;
472
+ box-shadow: 0 0 0 2px rgba(124,106,247,0.15) !important;
473
+ }
474
+
475
+ /* Buttons */
476
+ .gr-button {
477
+ font-family: 'Syne', sans-serif !important;
478
+ font-weight: 600 !important;
479
+ border-radius: 8px !important;
480
+ transition: all 0.15s !important;
481
+ }
482
+ .gr-button.primary {
483
+ background: #7c6af7 !important;
484
+ border: none !important;
485
+ color: #fff !important;
486
+ }
487
+ .gr-button.primary:hover { background: #6a58e0 !important; transform: translateY(-1px); }
488
+ .gr-button.secondary {
489
+ background: transparent !important;
490
+ border: 1px solid #2a2a3a !important;
491
+ color: #8a88a8 !important;
492
+ }
493
+ .gr-button.stop { background: #c0392b !important; color: #fff !important; border: none !important; }
494
+
495
+ /* Log box */
496
+ .log-box textarea {
497
+ font-family: 'Space Mono', monospace !important;
498
+ font-size: 0.72rem !important;
499
+ line-height: 1.6 !important;
500
+ background: #07070f !important;
501
+ border: 1px solid #1e1e2e !important;
502
+ color: #a8e6cf !important;
503
+ }
504
+
505
+ /* Progress bar */
506
+ .progress-bar-wrap .progress-bar { background: #7c6af7 !important; }
507
+
508
+ /* Slider labels */
509
+ .gr-form label {
510
+ font-family: 'Syne', sans-serif !important;
511
+ font-size: 0.82rem !important;
512
+ color: #8a88a8 !important;
513
+ font-weight: 600 !important;
514
+ }
515
+
516
+ /* Section labels */
517
+ .section-label {
518
+ font-family: 'Space Mono', monospace;
519
+ font-size: 0.65rem;
520
+ letter-spacing: 0.12em;
521
+ color: #4a4868;
522
+ text-transform: uppercase;
523
+ margin: 1rem 0 0.4rem;
524
+ }
525
+
526
+ /* Chat bubbles */
527
+ .message.user div { background: #1e1e30 !important; border-radius: 10px !important; }
528
+ .message.bot div { background: #12121e !important; border-radius: 10px !important; border: 1px solid #2a2a3a !important; }
529
+ """
530
+
531
+ def build_ui():
532
+ models = available_models()
533
+ if not models:
534
+ models = list(MODEL_CATALOG.keys())[:2] # fallback
535
+
536
+ with gr.Blocks(css=CSS, title="Fine-Tuning Studio") as demo:
537
+
538
+ # ── HEADER ────────────────────────────────────────────
539
+ gr.HTML(f"""
540
+ <div class="studio-header">
541
+ <h1>🧪 Fine-Tuning <span>Studio</span></h1>
542
+ <p>TREINE · CONVERTA · CONVERSE · EXPORTE</p>
543
+ </div>
544
+ <div class="hw-badge">{hw_banner()}</div>
545
+ """)
546
+
547
+ # ── TABS ──────────────────────────────────────────────
548
+ with gr.Tabs(elem_classes="tab-nav"):
549
+
550
+ # ════════════════════════════════
551
+ # TAB 1 — CONFIGURAR & TREINAR
552
+ # ════════════════════════════════
553
+ with gr.Tab("⚙️ Treinar"):
554
+ with gr.Row():
555
+
556
+ # Coluna esquerda — config
557
+ with gr.Column(scale=1):
558
+ gr.HTML('<div class="section-label">modelo</div>')
559
+ model_dd = gr.Dropdown(
560
+ choices=models,
561
+ value=models[0],
562
+ label="Modelo base",
563
+ interactive=True,
564
+ )
565
+ ft_mode_dd = gr.Dropdown(
566
+ choices=["LoRA", "QLoRA", "Full Fine-Tuning"],
567
+ value="LoRA" if PEFT_AVAILABLE else "Full Fine-Tuning",
568
+ label="Modo de fine-tuning",
569
+ )
570
+
571
+ gr.HTML('<div class="section-label">dataset</div>')
572
+ ds_source = gr.Radio(
573
+ choices=["HuggingFace Hub", "Upload CSV/JSONL", "Exemplo embutido"],
574
+ value="Exemplo embutido",
575
+ label="Fonte do dataset",
576
+ )
577
+ hf_ds_input = gr.Textbox(
578
+ placeholder="ex: tatsu-lab/alpaca",
579
+ label="Dataset ID (Hub)",
580
+ visible=False,
581
+ )
582
+ upload_file = gr.File(
583
+ label="CSV ou JSONL",
584
+ file_types=[".csv", ".jsonl"],
585
+ visible=False,
586
+ )
587
+
588
+ def toggle_ds(source):
589
+ return (
590
+ gr.update(visible=source == "HuggingFace Hub"),
591
+ gr.update(visible=source == "Upload CSV/JSONL"),
592
+ )
593
+ ds_source.change(toggle_ds, ds_source, [hf_ds_input, upload_file])
594
+
595
+ gr.HTML('<div class="section-label">hiperparâmetros</div>')
596
+ epochs_sl = gr.Slider(1, 10, value=3, step=1, label="Épocas")
597
+ batch_sl = gr.Slider(1, 16, value=2, step=1, label="Batch size")
598
+ lr_sl = gr.Slider(1e-5, 5e-4, value=2e-4, step=1e-5, label="Learning rate")
599
+ max_len_sl = gr.Slider(64, 2048, value=512, step=64, label="Max length (tokens)")
600
+ grad_acc_sl = gr.Slider(1, 16, value=4, step=1, label="Grad. accumulation")
601
+ warmup_sl = gr.Slider(0, 200, value=10, step=5, label="Warmup steps")
602
+ wd_sl = gr.Slider(0, 0.1, value=0.01, step=0.005, label="Weight decay")
603
+
604
+ # Coluna direita — logs
605
+ with gr.Column(scale=1):
606
+ gr.HTML('<div class="section-label">log de treinamento</div>')
607
+ log_box = gr.Textbox(
608
+ label="",
609
+ lines=24,
610
+ max_lines=24,
611
+ interactive=False,
612
+ elem_classes="log-box",
613
+ placeholder="O log aparecerá aqui quando o treinamento iniciar...",
614
+ )
615
+ progress = gr.Slider(
616
+ 0, 100, value=0, label="Progresso (%)", interactive=False
617
+ )
618
+
619
+ with gr.Row():
620
+ train_btn = gr.Button("🚀 Iniciar Treinamento", variant="primary", scale=3)
621
+ cancel_btn = gr.Button("⏹ Cancelar", variant="stop", scale=1)
622
+
623
+ status_md = gr.Markdown("")
624
+
625
+ # ── Lógica de treinamento ──────────────────
626
+ def start_training(
627
+ model_name, ft_mode,
628
+ ds_source, hf_ds, up_file,
629
+ epochs, batch, lr, max_len,
630
+ warmup, wd, grad_acc,
631
+ ):
632
+ if STATE.running:
633
+ return "⚠ Treinamento já em andamento."
634
+
635
+ thread = threading.Thread(
636
+ target=run_training,
637
+ args=(
638
+ model_name, ft_mode,
639
+ ds_source, hf_ds, up_file,
640
+ epochs, batch, lr, max_len,
641
+ warmup, wd, grad_acc,
642
+ ),
643
+ daemon=True,
644
+ )
645
+ thread.start()
646
+ return "▶ Treinamento iniciado..."
647
+
648
+ def cancel_training():
649
+ STATE.cancelled = True
650
+ return "⏹ Cancelamento solicitado."
651
+
652
+ def poll_logs():
653
+ """Polling a cada 2s para atualizar log e progresso."""
654
+ while True:
655
+ pct = 0
656
+ if STATE.total_steps > 0:
657
+ pct = min(100, int(STATE.progress / STATE.total_steps * 100))
658
+ yield STATE.log_box(), pct
659
+ time.sleep(2)
660
+
661
+ train_btn.click(
662
+ start_training,
663
+ inputs=[
664
+ model_dd, ft_mode_dd,
665
+ ds_source, hf_ds_input, upload_file,
666
+ epochs_sl, batch_sl, lr_sl, max_len_sl,
667
+ warmup_sl, wd_sl, grad_acc_sl,
668
+ ],
669
+ outputs=status_md,
670
+ )
671
+ cancel_btn.click(cancel_training, outputs=status_md)
672
+
673
+ demo.load(poll_logs, outputs=[log_box, progress], every=2)
674
+
675
+ # ════════════════════════════════
676
+ # TAB 2 — CHAT
677
+ # ════════════════════════════════
678
+ with gr.Tab("💬 Chat"):
679
+ gr.Markdown(
680
+ "**Converse com o modelo treinado.** Complete o treinamento na aba anterior primeiro.",
681
+ elem_id="chat-hint",
682
+ )
683
+ chatbot = gr.Chatbot(height=440, label="Conversa")
684
+
685
+ with gr.Row():
686
+ chat_input = gr.Textbox(
687
+ placeholder="Digite sua mensagem...",
688
+ label="",
689
+ scale=4,
690
+ )
691
+ send_btn = gr.Button("Enviar", variant="primary", scale=1)
692
+
693
+ with gr.Accordion("⚙️ Parâmetros de geração", open=False):
694
+ max_new_sl = gr.Slider(32, 1024, value=256, step=32, label="Max new tokens")
695
+ temp_sl = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
696
+
697
+ clear_btn = gr.Button("🗑 Limpar conversa", variant="secondary")
698
+
699
+ send_btn.click(
700
+ chat_with_model,
701
+ inputs=[chat_input, chatbot, max_new_sl, temp_sl],
702
+ outputs=chatbot,
703
+ )
704
+ chat_input.submit(
705
+ chat_with_model,
706
+ inputs=[chat_input, chatbot, max_new_sl, temp_sl],
707
+ outputs=chatbot,
708
+ )
709
+ clear_btn.click(lambda: [], outputs=chatbot)
710
+
711
+ # ════════════════════════════════
712
+ # TAB 3 — EXPORTAR
713
+ # ════════════════════════════════
714
+ with gr.Tab("📦 Exportar"):
715
+ gr.Markdown("### Download dos pesos treinados")
716
+ gr.Markdown(
717
+ "Após o treinamento, clique abaixo para gerar um `.zip` com todos os pesos e metadados."
718
+ )
719
+
720
+ with gr.Row():
721
+ zip_btn = gr.Button("📦 Gerar ZIP", variant="primary")
722
+ download_out = gr.File(label="Download", interactive=False)
723
+
724
+ export_status = gr.Markdown("")
725
+
726
+ def generate_zip():
727
+ path = create_download_zip()
728
+ if path:
729
+ return path, "✅ ZIP gerado! Clique para baixar."
730
+ return None, "⚠ Nenhum modelo treinado encontrado. Complete o treinamento primeiro."
731
+
732
+ zip_btn.click(generate_zip, outputs=[download_out, export_status])
733
+
734
+ gr.Markdown("---")
735
+ gr.Markdown("### Push para HuggingFace Hub")
736
+ gr.Markdown(
737
+ "Para fazer push do modelo para o Hub, configure o `HF_TOKEN` nas **Secrets** do Space "
738
+ "e use `model.push_to_hub('seu-usuario/nome-do-modelo')` no terminal."
739
+ )
740
+
741
+ # Metadados do treino
742
+ gr.HTML('<div class="section-label">metadados do treino</div>')
743
+
744
+ def get_meta():
745
+ meta_file = STATE.output_dir / "training_meta.json"
746
+ if meta_file.exists():
747
+ return meta_file.read_text()
748
+ return "Sem metadados ainda."
749
+
750
+ meta_box = gr.Code(label="training_meta.json", language="json", interactive=False)
751
+ refresh_meta_btn = gr.Button("🔄 Atualizar metadados", variant="secondary")
752
+ refresh_meta_btn.click(get_meta, outputs=meta_box)
753
+
754
+ # Footer
755
+ gr.HTML("""
756
+ <div style="text-align:center; margin-top:2rem; font-family:'Space Mono',monospace;
757
+ font-size:0.65rem; color:#3a3858; letter-spacing:0.1em;">
758
+ FINE-TUNING STUDIO · HUGGINGFACE SPACE · ADAPTA-SE AO HARDWARE DISPONÍVEL
759
+ </div>
760
+ """)
761
+
762
+ return demo
763
+
764
+
765
+ # ──────────────────────────────────────────────────────────────
766
+ # ENTRY POINT
767
+ # ──────────────────────────────────────────────────────────────
768
+
769
+ if __name__ == "__main__":
770
+ app = build_ui()
771
+ app.launch(
772
+ server_name="0.0.0.0",
773
+ server_port=7860,
774
+ share=False,
775
+ show_error=True,
776
+ )