Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,776 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
╔══════════════════════════════════════════════════════════════╗
|
| 3 |
+
║ 🧪 Fine-Tuning Studio — HuggingFace Space ║
|
| 4 |
+
║ Suporta: CPU / CPU Upgrade / T4 / A10G / A100 ║
|
| 5 |
+
║ Modos: LoRA, QLoRA, Full Fine-Tuning ║
|
| 6 |
+
║ Pós: Chat embutido + Download dos pesos ║
|
| 7 |
+
╚══════════════════════════════════════════════════════════════╝
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os, gc, json, math, shutil, threading, time, logging
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Optional
|
| 13 |
+
|
| 14 |
+
import gradio as gr
|
| 15 |
+
import torch
|
| 16 |
+
import pandas as pd
|
| 17 |
+
from datasets import load_dataset, Dataset
|
| 18 |
+
from transformers import (
|
| 19 |
+
AutoTokenizer,
|
| 20 |
+
AutoModelForCausalLM,
|
| 21 |
+
TrainingArguments,
|
| 22 |
+
Trainer,
|
| 23 |
+
DataCollatorForLanguageModeling,
|
| 24 |
+
BitsAndBytesConfig,
|
| 25 |
+
GenerationConfig,
|
| 26 |
+
TrainerCallback,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# ── PEFT é opcional; detectado em tempo de execução ───────────
|
| 30 |
+
try:
|
| 31 |
+
from peft import (
|
| 32 |
+
LoraConfig,
|
| 33 |
+
get_peft_model,
|
| 34 |
+
prepare_model_for_kbit_training,
|
| 35 |
+
PeftModel,
|
| 36 |
+
TaskType,
|
| 37 |
+
)
|
| 38 |
+
PEFT_AVAILABLE = True
|
| 39 |
+
except ImportError:
|
| 40 |
+
PEFT_AVAILABLE = False
|
| 41 |
+
|
| 42 |
+
logging.basicConfig(level=logging.INFO)
|
| 43 |
+
logger = logging.getLogger(__name__)
|
| 44 |
+
|
| 45 |
+
# ──────────────────────────────────────────────────────────────
|
| 46 |
+
# HARDWARE DETECTION
|
| 47 |
+
# ──────────────────────────────────────────────────────────────
|
| 48 |
+
|
| 49 |
+
def detect_hardware() -> dict:
    """Probe the runtime for a CUDA device and summarise what was found.

    Returns:
        A dict with the keys ``device`` ("cuda" or "cpu"), ``vram_gb``
        (total memory of CUDA device 0 divided by 1e9 and rounded to one
        decimal; 0 on CPU), ``gpu_name`` ("N/A" on CPU) and ``bf16``
        (what torch reports for bfloat16 support; False on CPU).
    """
    if not torch.cuda.is_available():
        return {"device": "cpu", "vram_gb": 0, "gpu_name": "N/A", "bf16": False}

    total_bytes = torch.cuda.get_device_properties(0).total_memory
    return {
        "device": "cuda",
        "vram_gb": round(total_bytes / 1e9, 1),
        "gpu_name": torch.cuda.get_device_name(0),
        "bf16": torch.cuda.is_bf16_supported(),
    }
|
| 57 |
+
|
| 58 |
+
HW = detect_hardware()
|
| 59 |
+
|
| 60 |
+
def hw_banner() -> str:
    """Build the one-line hardware summary rendered in the UI header badge."""
    if HW["device"] != "cuda":
        peft_mark = "✅" if PEFT_AVAILABLE else "❌"
        return (
            f"🔵 CPU | Threads: {torch.get_num_threads()} | "
            f"PEFT/LoRA: {peft_mark}"
        )

    # GPUs reporting at least 16 GB get the green tier, anything less yellow.
    if HW["vram_gb"] >= 16:
        tier = "🟢 GPU"
    else:
        tier = "🟡 GPU (pequena)"
    bf16_mark = "✅" if HW["bf16"] else "❌"
    peft_mark = "✅" if PEFT_AVAILABLE else "❌ (instale peft)"
    return (
        f"{tier} · {HW['gpu_name']} · {HW['vram_gb']} GB VRAM | "
        f"BF16: {bf16_mark} | "
        f"PEFT/LoRA: {peft_mark}"
    )
|
| 72 |
+
|
| 73 |
+
# ──────────────────────────────────────────────────────────────
|
| 74 |
+
# MODELO CATALOG (modelo_id, max_vram_recomendado_gb)
|
| 75 |
+
# ──────────────────────────────────────────────────────────────
|
| 76 |
+
|
| 77 |
+
# Maps the display name shown in the UI to a tuple of
# (Hugging Face model id, recommended VRAM in GB).
MODEL_CATALOG = {
    # Tiny — runs even on CPU
    "TinyLlama 1.1B": ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 4),
    "SmolLM 1.7B": ("HuggingFaceTB/SmolLM2-1.7B-Instruct", 6),
    # Medium — OK on a T4 with QLoRA
    "Mistral 7B": ("mistralai/Mistral-7B-Instruct-v0.2", 14),
    "Llama 3.1 8B": ("meta-llama/Meta-Llama-3.1-8B-Instruct", 16),
    "Gemma 2 9B": ("google/gemma-2-9b-it", 18),
    # Large — A10G / A100
    "Llama 3.1 70B": ("meta-llama/Meta-Llama-3.1-70B-Instruct", 80),
    "Mixtral 8x7B": ("mistralai/Mixtral-8x7B-Instruct-v0.1", 48),
}
|
| 89 |
+
|
| 90 |
+
def available_models() -> list[str]:
    """Return the catalog names whose VRAM requirement fits the current hardware."""
    if HW["device"] == "cuda":
        vram = HW["vram_gb"]
    else:
        vram = 2
    # Allow 20% over the detected VRAM, with a floor of 6 GB so the entries
    # that require 6 GB or less are always offered.
    budget = max(vram * 1.2, 6)
    fitting = []
    for name, (_, req) in MODEL_CATALOG.items():
        if req <= budget:
            fitting.append(name)
    return fitting
|
| 94 |
+
|
| 95 |
+
# ──────────────────────────────────────────────────────────────
|
| 96 |
+
# ESTADO GLOBAL DO TREINAMENTO
|
| 97 |
+
# ──────────────────────────────────────────────────────────────
|
| 98 |
+
|
| 99 |
+
class TrainingState:
    """Mutable holder for the status of the current (or most recent) training run."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Put every field back to its idle default."""
        self.running, self.cancelled = False, False
        self.logs: list = []
        self.progress: int = 0
        self.total_steps = 0
        self.model, self.tokenizer = None, None
        self.output_dir = Path("./trained_model")
        self.error: Optional[str] = None

    def log(self, msg: str):
        """Append a timestamped line to the log buffer and mirror it to the module logger."""
        stamp = time.strftime("%H:%M:%S")
        entry = f"[{stamp}] {msg}"
        self.logs.append(entry)
        logger.info(msg)

    def log_box(self) -> str:
        """Return the last 60 log lines joined with newlines."""
        tail = self.logs[-60:]
        return "\n".join(tail)
|
| 121 |
+
|
| 122 |
+
STATE = TrainingState()
|
| 123 |
+
|
| 124 |
+
# ──────────────────────────────────────────────────────────────
|
| 125 |
+
# CALLBACK PARA PROGRESSO EM TEMPO REAL
|
| 126 |
+
# ──────────────────────────────────────────────────────────────
|
| 127 |
+
|
| 128 |
+
class ProgressCallback(TrainerCallback):
    """Trainer callback that mirrors training progress into the global STATE."""

    def on_train_begin(self, args, state, control, **kwargs):
        """Record the planned number of steps and announce the start."""
        STATE.total_steps = state.max_steps
        STATE.log(f"▶ Treinamento iniciado — {state.max_steps} steps")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Publish the current step plus the loss / learning rate just logged."""
        if not logs:
            return
        step = state.global_step
        STATE.progress = step
        # Not every log entry carries both metrics; show a dash for missing ones.
        loss = logs.get("loss", "—")
        lr = logs.get("learning_rate", "—")
        STATE.log(f"Step {step}/{STATE.total_steps} loss={loss} lr={lr}")

    def on_step_end(self, args, state, control, **kwargs):
        """Ask the Trainer to stop once the user has requested cancellation."""
        if STATE.cancelled:
            control.should_training_stop = True

    def on_train_end(self, args, state, control, **kwargs):
        """Log how the run ended.

        This hook also runs when the loop stops because of a cancellation, so
        the message must distinguish the two outcomes instead of always
        claiming success.
        """
        if STATE.cancelled:
            STATE.log("⏹ Treinamento cancelado pelo usuário.")
        else:
            STATE.log("✅ Treinamento concluído!")
|
| 147 |
+
|
| 148 |
+
# ──────────────────────────────────────────────────────────────
|
| 149 |
+
# DATASET HELPERS
|
| 150 |
+
# ──────────────────────────────────────────────────────────────
|
| 151 |
+
|
| 152 |
+
def load_user_dataset(source: str, hf_dataset: str, uploaded_file) -> Dataset:
    """Load the training dataset from the source selected in the UI.

    Args:
        source: "HuggingFace Hub" or "Upload CSV/JSONL"; any other value
            selects the built-in example dataset.
        hf_dataset: Hub dataset id, used only for the Hub source. Blank or
            None falls through to the built-in example.
        uploaded_file: The upload from the file widget: either a filesystem
            path (str / PathLike) or an object exposing the path as
            ``.name``. None falls through to the built-in example.

    Returns:
        The loaded ``Dataset`` (the "train" split for Hub datasets).
    """
    hub_id = (hf_dataset or "").strip()
    if source == "HuggingFace Hub" and hub_id:
        return load_dataset(hub_id, split="train")

    if source == "Upload CSV/JSONL" and uploaded_file is not None:
        # The file widget may hand over a plain path or a tempfile-like
        # wrapper depending on the Gradio version/configuration; accept both.
        if isinstance(uploaded_file, (str, os.PathLike)):
            path = os.fspath(uploaded_file)
        else:
            path = uploaded_file.name
        # Case-insensitive so "DATA.CSV" is not parsed as JSON Lines.
        if path.lower().endswith(".csv"):
            df = pd.read_csv(path)
        else:
            df = pd.read_json(path, lines=True)
        return Dataset.from_pandas(df)

    # Fallback: small built-in example dataset
    examples = [
        {"text": "Instrução: Explique o que é machine learning.\nResposta: Machine learning é..."},
        {"text": "Instrução: O que é uma rede neural?\nResposta: Uma rede neural é..."},
        {"text": "Instrução: Como funciona o backpropagation?\nResposta: O backpropagation..."},
    ]
    return Dataset.from_list(examples)
|
| 173 |
+
|
| 174 |
+
def tokenize_dataset(dataset: Dataset, tokenizer, max_length: int) -> Dataset:
    """Tokenize the dataset's text column, padding/truncating to ``max_length``.

    The text column is the first of "text", "prompt", "instruction",
    "content" that exists; when none does, the dataset's first column is
    used. All original columns are dropped from the result.
    """
    columns = dataset.column_names
    text_col = columns[0]  # default when no preferred column name is present
    for candidate in ("text", "prompt", "instruction", "content"):
        if candidate in columns:
            text_col = candidate
            break

    def _encode(batch):
        encode_options = {
            "truncation": True,
            "max_length": max_length,
            "padding": "max_length",
        }
        return tokenizer(batch[text_col], **encode_options)

    return dataset.map(_encode, batched=True, remove_columns=columns)
|
| 189 |
+
|
| 190 |
+
# ──────────────────────────────────────────────────────────────
|
| 191 |
+
# CORE: CARREGA MODELO
|
| 192 |
+
# ──────────────────────────────────────────────────────────────
|
| 193 |
+
|
| 194 |
+
def load_model_and_tokenizer(model_name: str, ft_mode: str):
    """Load the base model and tokenizer, applying quantization / LoRA as requested.

    Args:
        model_name: A key of ``MODEL_CATALOG``.
        ft_mode: "LoRA", "QLoRA" or "Full Fine-Tuning".

    Returns:
        A ``(model, tokenizer)`` tuple. For LoRA/QLoRA the model is wrapped
        by ``get_peft_model``.

    Raises:
        KeyError: ``model_name`` is not in ``MODEL_CATALOG``.
        RuntimeError: LoRA/QLoRA was requested but ``peft`` is not installed.
    """
    model_id, _ = MODEL_CATALOG[model_name]
    wants_adapter = ft_mode in ("LoRA", "QLoRA")

    # Fail loudly instead of silently falling back to full fine-tuning when
    # the adapter library is missing.
    if wants_adapter and not PEFT_AVAILABLE:
        if ft_mode == "QLoRA":
            raise RuntimeError("Instale `peft` e `bitsandbytes` para QLoRA.")
        raise RuntimeError("Instale `peft` para LoRA.")

    STATE.log(f"⬇ Carregando tokenizer: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    if tokenizer.pad_token is None:
        # Padding is required by tokenize_dataset; reuse EOS when the
        # tokenizer ships without a pad token.
        tokenizer.pad_token = tokenizer.eos_token

    # Quantization config (4-bit only applies to QLoRA on a CUDA device)
    on_cuda = HW["device"] == "cuda"
    bnb_cfg = None
    if ft_mode == "QLoRA":
        if on_cuda:
            STATE.log("🔧 Configurando 4-bit NF4 (QLoRA)...")
            bnb_cfg = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16 if HW["bf16"] else torch.float16,
            )
        else:
            STATE.log("⚠ QLoRA requer GPU CUDA; usando LoRA sem quantização.")

    STATE.log(f"⬇ Carregando modelo ({ft_mode})...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_cfg,
        device_map="auto" if on_cuda else None,
        torch_dtype=torch.float16 if (on_cuda and not HW["bf16"]) else "auto",
        # NOTE(review): this lets code shipped in the model repo run; the
        # ids come from the fixed MODEL_CATALOG, not from user input.
        trust_remote_code=True,
    )

    if wants_adapter:
        if bnb_cfg is not None:
            model = prepare_model_for_kbit_training(model)

        lora_cfg = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        )
        model = get_peft_model(model, lora_cfg)
        model.print_trainable_parameters()
        # get_nb_trainable_parameters() yields (trainable, total); the
        # percentage has to be derived here.
        trainable, total = model.get_nb_trainable_parameters()
        pct = 100 * trainable / total if total else 0.0
        STATE.log(f"📊 Parâmetros treináveis: {trainable:,} / {total:,} ({pct:.2f}%)")

    elif ft_mode == "Full Fine-Tuning":
        STATE.log("⚠ Full fine-tuning: todos os pesos serão atualizados.")
        if not on_cuda:
            STATE.log("⚠ Full fine-tuning em CPU será MUITO lento.")

    return model, tokenizer
|
| 251 |
+
|
| 252 |
+
# ──────────────────────────────────────────────────────────────
|
| 253 |
+
# CORE: TREINAMENTO
|
| 254 |
+
# ──────────────────────────────────────────────────────────────
|
| 255 |
+
|
| 256 |
+
def run_training(
    model_name, ft_mode,
    dataset_source, hf_dataset, uploaded_file,
    epochs, batch_size, learning_rate, max_length,
    warmup_steps, weight_decay, grad_accum,
):
    """Run one complete fine-tuning job and publish its progress through STATE.

    Meant to run in a background thread: every exception is caught, stored
    in ``STATE.error`` and logged, and ``STATE.running`` is always cleared
    on exit. On success (and when not cancelled) the model, tokenizer and a
    ``training_meta.json`` file are written to ``STATE.output_dir`` and the
    model is kept in ``STATE.model``.

    Args:
        model_name: Key of ``MODEL_CATALOG``.
        ft_mode: "LoRA", "QLoRA" or "Full Fine-Tuning".
        dataset_source, hf_dataset, uploaded_file: Forwarded to
            ``load_user_dataset``.
        epochs, batch_size, learning_rate, max_length, warmup_steps,
        weight_decay, grad_accum: Hyperparameters taken from the UI sliders.
    """
    try:
        STATE.reset()
        STATE.running = True

        # reset() dropped the reference to any previously trained model;
        # reclaim that memory before loading the next one.
        gc.collect()
        if HW["device"] == "cuda":
            torch.cuda.empty_cache()

        # Values come from UI sliders and may arrive as floats, while these
        # settings are counts/lengths that must be integers.
        batch_size = int(batch_size)
        max_length = int(max_length)
        warmup_steps = int(warmup_steps)
        grad_accum = int(grad_accum)

        STATE.output_dir = Path(f"./trained_{model_name.replace(' ', '_')}_{ft_mode}")
        STATE.output_dir.mkdir(parents=True, exist_ok=True)

        STATE.log(f"🖥 Hardware: {HW['gpu_name'] if HW['device']=='cuda' else 'CPU'}")
        STATE.log(f"📦 Modelo: {model_name} | Modo: {ft_mode}")

        # 1. Dataset
        STATE.log("📂 Carregando dataset...")
        raw_ds = load_user_dataset(dataset_source, hf_dataset, uploaded_file)
        STATE.log(f"✅ Dataset: {len(raw_ds)} exemplos")

        # 2. Model
        model, tokenizer = load_model_and_tokenizer(model_name, ft_mode)
        STATE.tokenizer = tokenizer

        # 3. Tokenization (5% of the examples are held out for evaluation)
        STATE.log("🔤 Tokenizando dataset...")
        tokenized = tokenize_dataset(raw_ds, tokenizer, max_length)
        tokenized = tokenized.train_test_split(test_size=0.05, seed=42)

        # 4. TrainingArguments — prefer bf16 when the GPU supports it,
        # otherwise fp16; neither on CPU.
        use_fp16 = HW["device"] == "cuda" and not HW["bf16"]
        use_bf16 = HW["device"] == "cuda" and HW["bf16"]

        args = TrainingArguments(
            output_dir=str(STATE.output_dir),
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=grad_accum,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            fp16=use_fp16,
            bf16=use_bf16,
            logging_steps=5,
            save_steps=50,
            save_total_limit=2,
            eval_strategy="steps",
            eval_steps=50,
            load_best_model_at_end=True,
            report_to="none",
            dataloader_pin_memory=(HW["device"] == "cuda"),
        )

        collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=tokenized["train"],
            eval_dataset=tokenized["test"],
            data_collator=collator,
            callbacks=[ProgressCallback()],
        )

        STATE.log("🚀 Iniciando treinamento...")
        trainer.train()

        if not STATE.cancelled:
            STATE.log("💾 Salvando modelo...")
            model.save_pretrained(str(STATE.output_dir))
            tokenizer.save_pretrained(str(STATE.output_dir))

            # Save run metadata next to the weights
            meta = {
                "base_model": MODEL_CATALOG[model_name][0],
                "ft_mode": ft_mode,
                "epochs": epochs,
                "learning_rate": learning_rate,
                "dataset_source": dataset_source,
                "hardware": HW,
            }
            (STATE.output_dir / "training_meta.json").write_text(json.dumps(meta, indent=2))

            # Leave training mode so dropout is disabled when the chat tab
            # generates with this model.
            model.eval()
            STATE.model = model
            STATE.log(f"🎉 Modelo salvo em: {STATE.output_dir}")

    except Exception as e:
        # Top-level boundary of the worker thread: surface the failure in
        # the UI log instead of letting the thread die silently.
        STATE.error = str(e)
        STATE.log(f"❌ Erro: {e}")
        logger.exception(e)
    finally:
        STATE.running = False
|
| 348 |
+
|
| 349 |
+
# ──────────────────────────────────────────────────────────────
|
| 350 |
+
# CHAT COM MODELO TREINADO
|
| 351 |
+
# ──────────────────────────────────────────────────────────────
|
| 352 |
+
|
| 353 |
+
def chat_with_model(message: str, history: list, max_new_tokens: int, temperature: float):
    """Generate a reply from the trained model and append it to the chat history.

    Args:
        message: The new user message.
        history: Previous turns as ``[user, assistant]`` pairs; None is
            treated as an empty conversation.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; 0 selects greedy decoding.

    Returns:
        A new history list with the ``[message, reply]`` pair appended. When
        no trained model is available the reply is a warning message.
    """
    history = list(history or [])
    if STATE.model is None or STATE.tokenizer is None:
        return history + [[message, "⚠ Nenhum modelo treinado disponível. Complete o treinamento primeiro."]]

    # Flatten the conversation into a plain-text transcript ending with an
    # open assistant turn for the model to complete.
    turns = [f"Usuário: {user_msg}\nAssistente: {bot_msg}\n" for user_msg, bot_msg in history]
    prompt = "".join(turns) + f"Usuário: {message}\nAssistente:"

    inputs = STATE.tokenizer(prompt, return_tensors="pt")
    # Follow the model's own device rather than assuming "cuda".
    device = STATE.model.device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    sampling = temperature > 0
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": sampling,
        "pad_token_id": STATE.tokenizer.eos_token_id,
    }
    if sampling:
        # Temperature only matters when sampling; a value of 0 must not be
        # forwarded for greedy decoding.
        gen_kwargs["temperature"] = temperature

    with torch.no_grad():
        outputs = STATE.model.generate(**inputs, **gen_kwargs)

    # Keep only the tokens produced after the prompt.
    gen_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = STATE.tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
    return history + [[message, response]]
|
| 378 |
+
|
| 379 |
+
# ──────────────────────────────────────────────────────────────
|
| 380 |
+
# ZIP & DOWNLOAD
|
| 381 |
+
# ──────────────────────────────────────────────────────────────
|
| 382 |
+
|
| 383 |
+
def create_download_zip() -> Optional[str]:
    """Zip the training output directory and return the archive path.

    Returns:
        The path of the created ``model_export.zip`` archive, or None when
        there is nothing to export: the output directory does not exist, or
        a training run is still writing into it.
    """
    # The output directory is created at the start of a run, so its mere
    # existence does not mean the model has been saved yet.
    if STATE.running or not STATE.output_dir.is_dir():
        return None
    # make_archive returns the name of the file it actually created.
    return shutil.make_archive("model_export", "zip", str(STATE.output_dir))
|
| 389 |
+
|
| 390 |
+
# ──────────────────────────────────────────────────────────────
|
| 391 |
+
# GRADIO UI
|
| 392 |
+
# ──────────────────────────────────────────────────────────────
|
| 393 |
+
|
| 394 |
+
CSS = """
|
| 395 |
+
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=Syne:wght@400;600;800&display=swap');
|
| 396 |
+
|
| 397 |
+
* { box-sizing: border-box; }
|
| 398 |
+
|
| 399 |
+
body, .gradio-container {
|
| 400 |
+
background: #0a0a0f !important;
|
| 401 |
+
color: #e8e6f0 !important;
|
| 402 |
+
font-family: 'Syne', sans-serif !important;
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
.gr-panel, .gr-box, .gr-block { background: transparent !important; }
|
| 406 |
+
|
| 407 |
+
/* Header */
|
| 408 |
+
.studio-header {
|
| 409 |
+
text-align: center;
|
| 410 |
+
padding: 2rem 0 1rem;
|
| 411 |
+
border-bottom: 1px solid #2a2a3a;
|
| 412 |
+
margin-bottom: 1.5rem;
|
| 413 |
+
}
|
| 414 |
+
.studio-header h1 {
|
| 415 |
+
font-family: 'Syne', sans-serif;
|
| 416 |
+
font-weight: 800;
|
| 417 |
+
font-size: 2.2rem;
|
| 418 |
+
letter-spacing: -0.02em;
|
| 419 |
+
color: #fff;
|
| 420 |
+
margin: 0;
|
| 421 |
+
}
|
| 422 |
+
.studio-header h1 span { color: #7c6af7; }
|
| 423 |
+
.studio-header p {
|
| 424 |
+
font-family: 'Space Mono', monospace;
|
| 425 |
+
font-size: 0.75rem;
|
| 426 |
+
color: #6b6888;
|
| 427 |
+
margin-top: 0.4rem;
|
| 428 |
+
letter-spacing: 0.08em;
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
/* Hardware badge */
|
| 432 |
+
.hw-badge {
|
| 433 |
+
font-family: 'Space Mono', monospace;
|
| 434 |
+
font-size: 0.72rem;
|
| 435 |
+
background: #12121e;
|
| 436 |
+
border: 1px solid #2a2a3a;
|
| 437 |
+
border-radius: 6px;
|
| 438 |
+
padding: 0.5rem 1rem;
|
| 439 |
+
color: #8a88a8;
|
| 440 |
+
text-align: center;
|
| 441 |
+
margin-bottom: 1.2rem;
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
/* Tabs */
|
| 445 |
+
.tab-nav button {
|
| 446 |
+
font-family: 'Syne', sans-serif !important;
|
| 447 |
+
font-weight: 600 !important;
|
| 448 |
+
font-size: 0.85rem !important;
|
| 449 |
+
letter-spacing: 0.04em !important;
|
| 450 |
+
color: #6b6888 !important;
|
| 451 |
+
background: transparent !important;
|
| 452 |
+
border: none !important;
|
| 453 |
+
border-bottom: 2px solid transparent !important;
|
| 454 |
+
padding: 0.5rem 1.2rem !important;
|
| 455 |
+
}
|
| 456 |
+
.tab-nav button.selected {
|
| 457 |
+
color: #7c6af7 !important;
|
| 458 |
+
border-bottom-color: #7c6af7 !important;
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
/* Inputs */
|
| 462 |
+
.gr-input, .gr-dropdown select, textarea {
|
| 463 |
+
background: #12121e !important;
|
| 464 |
+
border: 1px solid #2a2a3a !important;
|
| 465 |
+
color: #e8e6f0 !important;
|
| 466 |
+
border-radius: 8px !important;
|
| 467 |
+
font-family: 'Space Mono', monospace !important;
|
| 468 |
+
font-size: 0.8rem !important;
|
| 469 |
+
}
|
| 470 |
+
.gr-input:focus, textarea:focus {
|
| 471 |
+
border-color: #7c6af7 !important;
|
| 472 |
+
box-shadow: 0 0 0 2px rgba(124,106,247,0.15) !important;
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
/* Buttons */
|
| 476 |
+
.gr-button {
|
| 477 |
+
font-family: 'Syne', sans-serif !important;
|
| 478 |
+
font-weight: 600 !important;
|
| 479 |
+
border-radius: 8px !important;
|
| 480 |
+
transition: all 0.15s !important;
|
| 481 |
+
}
|
| 482 |
+
.gr-button.primary {
|
| 483 |
+
background: #7c6af7 !important;
|
| 484 |
+
border: none !important;
|
| 485 |
+
color: #fff !important;
|
| 486 |
+
}
|
| 487 |
+
.gr-button.primary:hover { background: #6a58e0 !important; transform: translateY(-1px); }
|
| 488 |
+
.gr-button.secondary {
|
| 489 |
+
background: transparent !important;
|
| 490 |
+
border: 1px solid #2a2a3a !important;
|
| 491 |
+
color: #8a88a8 !important;
|
| 492 |
+
}
|
| 493 |
+
.gr-button.stop { background: #c0392b !important; color: #fff !important; border: none !important; }
|
| 494 |
+
|
| 495 |
+
/* Log box */
|
| 496 |
+
.log-box textarea {
|
| 497 |
+
font-family: 'Space Mono', monospace !important;
|
| 498 |
+
font-size: 0.72rem !important;
|
| 499 |
+
line-height: 1.6 !important;
|
| 500 |
+
background: #07070f !important;
|
| 501 |
+
border: 1px solid #1e1e2e !important;
|
| 502 |
+
color: #a8e6cf !important;
|
| 503 |
+
}
|
| 504 |
+
|
| 505 |
+
/* Progress bar */
|
| 506 |
+
.progress-bar-wrap .progress-bar { background: #7c6af7 !important; }
|
| 507 |
+
|
| 508 |
+
/* Slider labels */
|
| 509 |
+
.gr-form label {
|
| 510 |
+
font-family: 'Syne', sans-serif !important;
|
| 511 |
+
font-size: 0.82rem !important;
|
| 512 |
+
color: #8a88a8 !important;
|
| 513 |
+
font-weight: 600 !important;
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
/* Section labels */
|
| 517 |
+
.section-label {
|
| 518 |
+
font-family: 'Space Mono', monospace;
|
| 519 |
+
font-size: 0.65rem;
|
| 520 |
+
letter-spacing: 0.12em;
|
| 521 |
+
color: #4a4868;
|
| 522 |
+
text-transform: uppercase;
|
| 523 |
+
margin: 1rem 0 0.4rem;
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
/* Chat bubbles */
|
| 527 |
+
.message.user div { background: #1e1e30 !important; border-radius: 10px !important; }
|
| 528 |
+
.message.bot div { background: #12121e !important; border-radius: 10px !important; border: 1px solid #2a2a3a !important; }
|
| 529 |
+
"""
|
| 530 |
+
|
| 531 |
+
def build_ui():
    """Assemble the Gradio Blocks app: Train, Chat and Export tabs.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    models = available_models()
    if not models:
        models = list(MODEL_CATALOG.keys())[:2]  # fallback

    with gr.Blocks(css=CSS, title="Fine-Tuning Studio") as demo:

        # ── HEADER ────────────────────────────────────────────
        gr.HTML(f"""
        <div class="studio-header">
            <h1>🧪 Fine-Tuning <span>Studio</span></h1>
            <p>TREINE · CONVERTA · CONVERSE · EXPORTE</p>
        </div>
        <div class="hw-badge">{hw_banner()}</div>
        """)

        # ── TABS ──────────────────────────────────────────────
        with gr.Tabs(elem_classes="tab-nav"):

            # ════════════════════════════════
            # TAB 1 — CONFIGURE & TRAIN
            # ════════════════════════════════
            with gr.Tab("⚙️ Treinar"):
                with gr.Row():

                    # Left column — configuration
                    with gr.Column(scale=1):
                        gr.HTML('<div class="section-label">modelo</div>')
                        model_dd = gr.Dropdown(
                            choices=models,
                            value=models[0],
                            label="Modelo base",
                            interactive=True,
                        )
                        ft_mode_dd = gr.Dropdown(
                            choices=["LoRA", "QLoRA", "Full Fine-Tuning"],
                            value="LoRA" if PEFT_AVAILABLE else "Full Fine-Tuning",
                            label="Modo de fine-tuning",
                        )

                        gr.HTML('<div class="section-label">dataset</div>')
                        ds_source = gr.Radio(
                            choices=["HuggingFace Hub", "Upload CSV/JSONL", "Exemplo embutido"],
                            value="Exemplo embutido",
                            label="Fonte do dataset",
                        )
                        hf_ds_input = gr.Textbox(
                            placeholder="ex: tatsu-lab/alpaca",
                            label="Dataset ID (Hub)",
                            visible=False,
                        )
                        upload_file = gr.File(
                            label="CSV ou JSONL",
                            file_types=[".csv", ".jsonl"],
                            visible=False,
                        )

                        def toggle_ds(source):
                            """Show only the input that matches the chosen dataset source."""
                            return (
                                gr.update(visible=source == "HuggingFace Hub"),
                                gr.update(visible=source == "Upload CSV/JSONL"),
                            )
                        ds_source.change(toggle_ds, ds_source, [hf_ds_input, upload_file])

                        gr.HTML('<div class="section-label">hiperparâmetros</div>')
                        epochs_sl = gr.Slider(1, 10, value=3, step=1, label="Épocas")
                        batch_sl = gr.Slider(1, 16, value=2, step=1, label="Batch size")
                        lr_sl = gr.Slider(1e-5, 5e-4, value=2e-4, step=1e-5, label="Learning rate")
                        max_len_sl = gr.Slider(64, 2048, value=512, step=64, label="Max length (tokens)")
                        grad_acc_sl = gr.Slider(1, 16, value=4, step=1, label="Grad. accumulation")
                        warmup_sl = gr.Slider(0, 200, value=10, step=5, label="Warmup steps")
                        wd_sl = gr.Slider(0, 0.1, value=0.01, step=0.005, label="Weight decay")

                    # Right column — logs
                    with gr.Column(scale=1):
                        gr.HTML('<div class="section-label">log de treinamento</div>')
                        log_box = gr.Textbox(
                            label="",
                            lines=24,
                            max_lines=24,
                            interactive=False,
                            elem_classes="log-box",
                            placeholder="O log aparecerá aqui quando o treinamento iniciar...",
                        )
                        progress = gr.Slider(
                            0, 100, value=0, label="Progresso (%)", interactive=False
                        )

                        with gr.Row():
                            train_btn = gr.Button("🚀 Iniciar Treinamento", variant="primary", scale=3)
                            cancel_btn = gr.Button("⏹ Cancelar", variant="stop", scale=1)

                        status_md = gr.Markdown("")

                # ── Training logic ──────────────────
                def start_training(
                    model_name, ft_mode,
                    ds_source, hf_ds, up_file,
                    epochs, batch, lr, max_len,
                    warmup, wd, grad_acc,
                ):
                    """Start run_training in a daemon thread unless a run is already active."""
                    if STATE.running:
                        return "⚠ Treinamento já em andamento."

                    thread = threading.Thread(
                        target=run_training,
                        args=(
                            model_name, ft_mode,
                            ds_source, hf_ds, up_file,
                            epochs, batch, lr, max_len,
                            warmup, wd, grad_acc,
                        ),
                        daemon=True,
                    )
                    thread.start()
                    return "▶ Treinamento iniciado..."

                def cancel_training():
                    """Flag the run as cancelled; the training callback stops at the next step."""
                    STATE.cancelled = True
                    return "⏹ Cancelamento solicitado."

                def poll_logs():
                    """Return the current log tail and completion percentage.

                    Registered below with ``every=2``, so the periodic
                    re-invocation provides the polling. The function must
                    return once per call rather than loop forever, otherwise
                    every page load would occupy a worker indefinitely.
                    """
                    pct = 0
                    if STATE.total_steps > 0:
                        pct = min(100, int(STATE.progress / STATE.total_steps * 100))
                    return STATE.log_box(), pct

                train_btn.click(
                    start_training,
                    inputs=[
                        model_dd, ft_mode_dd,
                        ds_source, hf_ds_input, upload_file,
                        epochs_sl, batch_sl, lr_sl, max_len_sl,
                        warmup_sl, wd_sl, grad_acc_sl,
                    ],
                    outputs=status_md,
                )
                cancel_btn.click(cancel_training, outputs=status_md)

                demo.load(poll_logs, outputs=[log_box, progress], every=2)

            # ════════════════════════════════
            # TAB 2 — CHAT
            # ════════════════════════════════
            with gr.Tab("💬 Chat"):
                gr.Markdown(
                    "**Converse com o modelo treinado.** Complete o treinamento na aba anterior primeiro.",
                    elem_id="chat-hint",
                )
                chatbot = gr.Chatbot(height=440, label="Conversa")

                with gr.Row():
                    chat_input = gr.Textbox(
                        placeholder="Digite sua mensagem...",
                        label="",
                        scale=4,
                    )
                    send_btn = gr.Button("Enviar", variant="primary", scale=1)

                with gr.Accordion("⚙️ Parâmetros de geração", open=False):
                    max_new_sl = gr.Slider(32, 1024, value=256, step=32, label="Max new tokens")
                    temp_sl = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")

                clear_btn = gr.Button("🗑 Limpar conversa", variant="secondary")

                send_btn.click(
                    chat_with_model,
                    inputs=[chat_input, chatbot, max_new_sl, temp_sl],
                    outputs=chatbot,
                )
                chat_input.submit(
                    chat_with_model,
                    inputs=[chat_input, chatbot, max_new_sl, temp_sl],
                    outputs=chatbot,
                )
                clear_btn.click(lambda: [], outputs=chatbot)

            # ════════════════════════════════
            # TAB 3 — EXPORT
            # ════════════════════════════════
            with gr.Tab("📦 Exportar"):
                gr.Markdown("### Download dos pesos treinados")
                gr.Markdown(
                    "Após o treinamento, clique abaixo para gerar um `.zip` com todos os pesos e metadados."
                )

                with gr.Row():
                    zip_btn = gr.Button("📦 Gerar ZIP", variant="primary")
                    download_out = gr.File(label="Download", interactive=False)

                export_status = gr.Markdown("")

                def generate_zip():
                    """Create the export archive and report the outcome to the UI."""
                    path = create_download_zip()
                    if path:
                        return path, "✅ ZIP gerado! Clique para baixar."
                    return None, "⚠ Nenhum modelo treinado encontrado. Complete o treinamento primeiro."

                zip_btn.click(generate_zip, outputs=[download_out, export_status])

                gr.Markdown("---")
                gr.Markdown("### Push para HuggingFace Hub")
                gr.Markdown(
                    "Para fazer push do modelo para o Hub, configure o `HF_TOKEN` nas **Secrets** do Space "
                    "e use `model.push_to_hub('seu-usuario/nome-do-modelo')` no terminal."
                )

                # Training metadata
                gr.HTML('<div class="section-label">metadados do treino</div>')

                def get_meta():
                    """Return the saved training_meta.json text, or a placeholder when absent."""
                    meta_file = STATE.output_dir / "training_meta.json"
                    if meta_file.exists():
                        return meta_file.read_text()
                    return "Sem metadados ainda."

                meta_box = gr.Code(label="training_meta.json", language="json", interactive=False)
                refresh_meta_btn = gr.Button("🔄 Atualizar metadados", variant="secondary")
                refresh_meta_btn.click(get_meta, outputs=meta_box)

        # Footer
        gr.HTML("""
        <div style="text-align:center; margin-top:2rem; font-family:'Space Mono',monospace;
                    font-size:0.65rem; color:#3a3858; letter-spacing:0.1em;">
            FINE-TUNING STUDIO · HUGGINGFACE SPACE · ADAPTA-SE AO HARDWARE DISPONÍVEL
        </div>
        """)

    return demo
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
# ──────────────────────────────────────────────────────────────
|
| 766 |
+
# ENTRY POINT
|
| 767 |
+
# ──────────────────────────────────────────────────────────────
|
| 768 |
+
|
| 769 |
+
if __name__ == "__main__":
    # Build the UI, then serve it on all interfaces at port 7860 without a
    # public share link, surfacing handler errors in the browser.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    app = build_ui()
    app.launch(**launch_options)
|