File size: 7,936 Bytes
9275790
bba7e77
 
 
4cafbc8
bba7e77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cafbc8
 
 
09dea70
4cafbc8
 
594c14a
 
 
 
09dea70
1ab5ebf
 
594c14a
 
 
 
 
 
09dea70
1ab5ebf
594c14a
09dea70
bba7e77
594c14a
1ab5ebf
09dea70
594c14a
09dea70
1ab5ebf
 
 
 
 
 
 
 
 
 
 
bba7e77
09dea70
1ab5ebf
 
 
 
09dea70
bba7e77
1ab5ebf
 
09dea70
1ab5ebf
09dea70
 
1ab5ebf
 
 
 
 
 
 
 
bba7e77
1ab5ebf
 
 
bba7e77
1ab5ebf
bba7e77
1ab5ebf
 
 
09dea70
bba7e77
09dea70
 
bba7e77
594c14a
 
 
1ab5ebf
594c14a
 
09dea70
1ab5ebf
09dea70
1ab5ebf
09dea70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594c14a
09dea70
bba7e77
1ab5ebf
09dea70
4cafbc8
9275790
 
4cafbc8
9275790
 
 
4cafbc8
9275790
 
4cafbc8
9275790
09dea70
 
9275790
 
4cafbc8
9275790
594c14a
1ab5ebf
594c14a
1ab5ebf
594c14a
 
 
9275790
bba7e77
09dea70
9275790
bba7e77
9275790
1ab5ebf
9275790
 
bba7e77
9275790
 
09dea70
 
 
 
 
 
 
 
 
1ab5ebf
594c14a
 
 
 
 
 
1ab5ebf
bba7e77
594c14a
 
 
 
bba7e77
 
594c14a
09dea70
9275790
09dea70
9275790
4cafbc8
9275790
4cafbc8
 
 
9275790
4cafbc8
9275790
 
 
 
 
594c14a
9275790
 
 
594c14a
 
 
 
4cafbc8
 
 
1ab5ebf
9275790
1ab5ebf
4cafbc8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# app.py
# ⚠️ NÃO REMOVA: Força instalação de pacotes críticos antes de qualquer import
import subprocess
import sys
import os
import shutil

def install_packages():
    """Force-reinstall the pinned ML packages this app depends on.

    Runs ``pip install --force-reinstall`` through the current interpreter
    (``sys.executable``), so the packages land in the same environment that
    will import them later in this file.

    Returns:
        True if pip exited successfully, False otherwise. Failures are
        printed rather than raised, so the app still tries to start with
        whatever is already installed.
    """
    print("🔧 Forçando reinstalação de diffusers e huggingface_hub...")
    try:
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "--force-reinstall",
            "diffusers>=0.26.0",
            "huggingface-hub>=0.20.0",
            "accelerate",
            "peft",
            "torch==2.3.0",
            "transformers==4.40.0"
        ])
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: pip exited non-zero; OSError: pip could not be launched.
        print(f"❌ Falha na instalação forçada: {e}")
        return False
    print("✅ Pacotes essenciais instalados!")
    return True


# Install only once per session. A marker file in /tmp records a *successful*
# install. It is written only on success, so a failed install is retried on
# the next start instead of being skipped forever.
_INSTALL_MARKER = "/tmp/packages_installed"
if not os.path.exists(_INSTALL_MARKER):
    if install_packages():
        with open(_INSTALL_MARKER, "w", encoding="utf-8") as f:
            f.write("ok")

# Agora sim, imports seguros
import gradio as gr
from preprocess import process_dataset
import subprocess
import zipfile
import time

def _upload_path(uploaded_file):
    """Return the filesystem path of one uploaded file, or None if unknown.

    Accepts either a plain path string or an object exposing the path via a
    ``.name`` attribute (the shape the original code expected from Gradio).
    """
    if isinstance(uploaded_file, str):
        return uploaded_file
    return getattr(uploaded_file, "name", None)


def _clear_directory(path):
    """Delete every entry inside *path*, keeping *path* itself.

    Returns a list of warning messages, one per entry that could not be removed.
    """
    warnings = []
    for entry in os.listdir(path):
        entry_path = os.path.join(path, entry)
        try:
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except OSError as e:
            warnings.append(f"⚠️ Erro ao limpar: {e}")
    return warnings


def _extract_zip_flat(zip_source, dest_dir):
    """Extract the regular files of a ZIP archive directly into *dest_dir*.

    Folder structure inside the archive is discarded. Images zipped inside a
    folder therefore still land at the top level of *dest_dir*, which is the
    only level the image scan and caption writer look at.

    The following entries are skipped:
    - directories;
    - anything under ``__MACOSX/``;
    - hidden files whose base name starts with ``.`` (including AppleDouble
      ``._*`` files).

    Only the base name is used as the target, so targets contain no path
    separators. When two entries share a base name, the later one overwrites
    the earlier one.

    Returns:
        The number of files written.

    Raises:
        zipfile.BadZipFile, RuntimeError (e.g. encrypted or unsupported
        entries), OSError: propagated from reading or writing.
    """
    written = 0
    with zipfile.ZipFile(zip_source, "r") as archive:
        for info in archive.infolist():
            if info.is_dir():
                continue
            parts = info.filename.replace("\\", "/").split("/")
            name = parts[-1]
            if not name or name.startswith(".") or "__MACOSX" in parts:
                continue
            target = os.path.join(dest_dir, name)
            with archive.open(info) as src, open(target, "wb") as dst:
                shutil.copyfileobj(src, dst)
            written += 1
    return written


def _copy_uploaded_images(uploaded_files, dest_dir):
    """Copy uploaded files into *dest_dir*, keeping their base names.

    Files are copied, not renamed, so the uploaded source files stay in place.

    Returns:
        ``(copied, invalid)``: the number of files copied, and the uploads
        whose path could not be determined.
    """
    copied = 0
    invalid = []
    for uploaded_file in uploaded_files:
        src_path = _upload_path(uploaded_file)
        if not src_path:
            invalid.append(uploaded_file)
            continue
        shutil.copy(src_path, os.path.join(dest_dir, os.path.basename(src_path)))
        copied += 1
    return copied, invalid


def _write_default_captions(dataset_dir, image_names, caption):
    """Write *caption* to ``<stem>.txt`` for every image without a caption file.

    Existing ``.txt`` files (e.g. shipped inside the ZIP) are left untouched.
    """
    for img_name in image_names:
        txt_path = os.path.join(dataset_dir, os.path.splitext(img_name)[0] + ".txt")
        if not os.path.exists(txt_path):
            with open(txt_path, "w", encoding="utf-8") as f:
                f.write(caption)


def train_lora_interface(
    dataset_input, input_type, model_name, lora_rank, learning_rate,
    num_epochs, hub_token, concept_name, description
):
    """Prepare an image dataset and run ``train_lora.py`` on it, streaming status.

    This is a generator: every ``yield`` is a status string for the UI.

    Args:
        dataset_input: One uploaded file or a list of them (a path string, or
            an object with a ``.name`` path).
        input_type: ``"Upload de ZIP"`` treats the first upload as a ZIP
            archive; any other value treats the uploads as individual images.
        model_name: Base model id, forwarded to ``--model_name``.
        lora_rank: LoRA rank, forwarded as an integer.
        learning_rate: Forwarded to ``--learning_rate``.
        num_epochs: Number of epochs, forwarded as an integer.
        hub_token: Optional Hugging Face token. When non-blank, it is passed
            to the training process only, as ``HF_TOKEN``, and enables
            ``--push_to_hub``.
        concept_name: Trigger word; spaces are replaced by underscores.
        description: Base caption written for images lacking a ``.txt`` file.

    Yields:
        Progress, warning and result messages (in Portuguese).
    """
    # Tolerate None from the UI instead of crashing on .strip().
    concept_name = (concept_name or "").strip()
    description = (description or "").strip()
    hub_token = (hub_token or "").strip()

    if not dataset_input:
        yield "❌ Por favor, envie um ZIP ou selecione imagens."
        return
    if not concept_name:
        yield "❌ Por favor, defina um nome para o conceito (ex: brenda)."
        return
    if not description:
        yield "❌ Por favor, adicione uma descrição base (ex: mulher, 30 anos, cabelo cacheado)."
        return

    concept_name = concept_name.replace(" ", "_")
    full_description = f"{description}, {concept_name}"

    yield f"🏷️ Treinando conceito: '{concept_name}' → Prompt: [photo of {concept_name}]"

    # Working folder, emptied first so files from a previous run are not reused.
    dataset_dir = "processed_data"
    os.makedirs(dataset_dir, exist_ok=True)
    for warning in _clear_directory(dataset_dir):
        yield warning

    # --- Step 1: bring the uploaded data into dataset_dir ---
    if input_type == "Upload de ZIP":
        zip_file = dataset_input[0] if isinstance(dataset_input, list) else dataset_input

        if not zipfile.is_zipfile(zip_file):
            yield "❌ Arquivo não é um ZIP válido."
            return

        yield "📦 Descompactando ZIP..."
        try:
            extracted = _extract_zip_flat(zip_file, dataset_dir)
        except (zipfile.BadZipFile, RuntimeError, OSError) as e:
            yield f"❌ Falha ao extrair o ZIP: {e}"
            return
        yield f"✅ ZIP extraído! {extracted} arquivos."

    else:  # individual images
        image_files = dataset_input if isinstance(dataset_input, list) else [dataset_input]
        yield f"🖼️ Recebidas {len(image_files)} imagens. Copiando..."

        copied, invalid = _copy_uploaded_images(image_files, dataset_dir)
        for bad in invalid:
            yield f"⚠️ Arquivo inválido: {bad}"
        yield f"✅ {copied} imagens copiadas."

    # --- Step 2: make sure there are images, then caption them ---
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')
    image_files = [f for f in os.listdir(dataset_dir) if f.lower().endswith(image_extensions)]

    if not image_files:
        yield "❌ Nenhuma imagem encontrada. Envie arquivos válidos."
        return

    yield f"📝 Aplicando legenda base: '{full_description}'"
    _write_default_captions(dataset_dir, image_files, full_description)
    yield "🔍 Legendas aplicadas com sucesso!"

    # --- Step 3: training ---
    output_dir = "lora-output"
    os.makedirs(output_dir, exist_ok=True)

    # Use the same interpreter that install_packages() installs into.
    cmd = [
        sys.executable, "train_lora.py",
        "--dataset_dir", dataset_dir,
        "--model_name", model_name,
        "--lora_rank", str(int(lora_rank)),
        "--learning_rate", str(learning_rate),
        "--num_epochs", str(int(num_epochs)),
        "--batch_size", "1",
        "--output_dir", output_dir,
    ]

    # The token is given to the child process only, through its environment.
    # It never appears on the command line and never lingers in this
    # process's os.environ for later runs.
    env = os.environ.copy()
    # A piped stdout is block-buffered in the child; unbuffered output lets
    # loss/epoch lines reach the UI as they are printed.
    env.setdefault("PYTHONUNBUFFERED", "1")
    if hub_token:
        env["HF_TOKEN"] = hub_token
        cmd += ["--push_to_hub", "--hub_model_id", f"{concept_name}-lora"]

    yield "🔥 Iniciando treinamento LoRA... Isso pode levar alguns minutos."

    log_lines = []
    process = None
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            encoding="utf-8",
            errors="replace",  # never crash on non-UTF-8 bytes in tool output
            env=env,
        )
        for line in process.stdout:
            log_lines.append(line)
            lowered = line.lower()
            if "loss" in lowered or "epoch" in lowered:
                yield f"📊 {line.strip()}"
        returncode = process.wait()
    except Exception as e:
        yield f"💥 Erro ao executar treinamento: {str(e)}"
        return
    finally:
        # Also runs if this generator is closed before finishing: make sure
        # the training process does not outlive it, and release the pipe.
        if process is not None:
            if process.poll() is None:
                process.kill()
                process.wait()
            process.stdout.close()

    if returncode == 0:
        yield f"""
🎉 TREINAMENTO CONCLUÍDO!

🔹 Use no prompt: `photo of {concept_name} in the forest`
🔹 Modelo salvo em: `{output_dir}`
{'🔹 Publicado no Hub!' if hub_token else ''}
            """
    else:
        log_tail = "".join(log_lines)[-1000:]
        yield f"❌ Treinamento falhou. Código: {returncode}\nLogs:\n{log_tail}"

# --- Interface Gradio ---
# Gradio UI, in order:
# - input mode and uploads;
# - concept identity (trigger word and base caption);
# - training hyper-parameters and an optional Hub token;
# - a button that runs train_lora_interface and shows the status messages it
#   yields in `output`.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Treinador de LoRA - Hugging Face")
    gr.Markdown("Treine seu próprio modelo com nome, descrição e imagens!")

    with gr.Row():
        # "Upload de ZIP" makes train_lora_interface treat the first upload as a
        # ZIP archive; the other option treats the uploads as individual images.
        input_type = gr.Radio(
            ["Upload de ZIP", "Selecionar várias imagens"],
            label="Tipo de Entrada",
            value="Upload de ZIP"
        )

    with gr.Row():
        dataset_input = gr.File(
            label="📤 Envie seu ZIP ou imagens",
            file_types=[".zip", ".jpg", ".jpeg", ".png", ".bmp", ".webp"],
            file_count="multiple"
        )

    gr.Markdown("### 🔖 Identidade do Personagem/Conceito")
    with gr.Row():
        concept_name = gr.Textbox(
            label="Nome do Conceito (ex: brenda)",
            placeholder="Ex: brenda, cyborg_x, estilo_pintura",
            value=""
        )
    with gr.Row():
        description = gr.Textbox(
            label="Descrição Base (ex: woman, curly hair, realistic)",
            placeholder="Ex: young black woman, warm smile, detailed face",
            lines=2
        )

    gr.Markdown("### ⚙️ Configurações do Treinamento")
    with gr.Row():
        # The "runwayml/stable-diffusion-v1-5" repository was removed from the
        # Hugging Face Hub. The same SD 1.5 weights are mirrored at
        # "stable-diffusion-v1-5/stable-diffusion-v1-5", which is the default.
        # The old id stays selectable for setups that can still resolve it
        # (e.g. from a local cache).
        model_name = gr.Dropdown(
            [
                "stable-diffusion-v1-5/stable-diffusion-v1-5",
                "runwayml/stable-diffusion-v1-5",
            ],
            value="stable-diffusion-v1-5/stable-diffusion-v1-5",
            label="Modelo Base"
        )
        lora_rank = gr.Slider(4, 64, value=4, step=4, label="LoRA Rank")
        learning_rate = gr.Number(value=1e-4, label="Taxa de Aprendizado")
        num_epochs = gr.Slider(1, 30, value=10, step=1, label="Épocas")

    # Optional; when filled, the trained LoRA is pushed to the Hub.
    hub_token = gr.Textbox(label="🔐 Token do Hugging Face (opcional)", type="password")

    btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
    output = gr.Textbox(label="📦 Logs e Resultado", lines=12)

    # Input order must match train_lora_interface's positional parameters.
    btn.click(
        train_lora_interface,
        inputs=[
            dataset_input, input_type, model_name, lora_rank,
            learning_rate, num_epochs, hub_token, concept_name, description
        ],
        outputs=output
    )

# Enable the event queue, which Gradio uses to stream the values yielded by
# generator handlers such as train_lora_interface.
demo.queue()


def _launch_app():
    """Start the Gradio server for ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _launch_app()