File size: 6,371 Bytes
9275790
4cafbc8
 
 
 
09dea70
81d438f
4cafbc8
 
594c14a
 
 
 
09dea70
1ab5ebf
 
594c14a
 
 
 
81d438f
594c14a
09dea70
1ab5ebf
594c14a
09dea70
81d438f
594c14a
09dea70
594c14a
09dea70
1ab5ebf
 
 
 
 
 
 
 
 
 
 
bba7e77
09dea70
1ab5ebf
 
 
 
09dea70
bba7e77
81d438f
1ab5ebf
09dea70
1ab5ebf
09dea70
81d438f
1ab5ebf
81d438f
1ab5ebf
 
 
81d438f
 
 
1ab5ebf
bba7e77
1ab5ebf
81d438f
 
 
 
 
 
09dea70
 
81d438f
594c14a
81d438f
 
 
 
594c14a
09dea70
81d438f
09dea70
1ab5ebf
09dea70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81d438f
09dea70
81d438f
1ab5ebf
09dea70
4cafbc8
9275790
 
4cafbc8
9275790
 
 
4cafbc8
9275790
 
4cafbc8
9275790
09dea70
 
9275790
 
4cafbc8
9275790
594c14a
81d438f
594c14a
1ab5ebf
594c14a
 
 
9275790
81d438f
09dea70
9275790
81d438f
9275790
81d438f
9275790
 
81d438f
9275790
 
09dea70
 
 
 
 
 
 
 
 
1ab5ebf
594c14a
 
 
81d438f
594c14a
 
1ab5ebf
81d438f
594c14a
 
 
 
81d438f
 
594c14a
09dea70
9275790
81d438f
9275790
4cafbc8
9275790
4cafbc8
 
 
9275790
4cafbc8
9275790
 
81d438f
9275790
 
81d438f
9275790
 
 
594c14a
 
 
 
4cafbc8
 
 
9275790
4cafbc8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# app.py
import os
import gradio as gr
from preprocess import process_dataset
import subprocess
import zipfile
import shutil
import time

def _uploaded_path(item):
    """Return the filesystem path of one uploaded item, or None if unknown.

    Uploads may arrive either as plain paths (``str`` / ``os.PathLike``) or
    as objects that expose their path through a ``name`` attribute.
    """
    if isinstance(item, (str, os.PathLike)):
        return os.fspath(item)
    return getattr(item, "name", None)


def _clear_directory(directory):
    """Delete every entry inside *directory* and return the errors hit.

    Cleanup is best-effort: a failure on one entry is recorded and the
    remaining entries are still processed.
    """
    errors = []
    for entry in os.listdir(directory):
        entry_path = os.path.join(directory, entry)
        try:
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except OSError as exc:
            errors.append(exc)
    return errors


def train_lora_interface(
    dataset_input, input_type, model_name, lora_rank, learning_rate,
    num_epochs, hub_token, concept_name, description
):
    """Prepare a dataset, caption it and run ``train_lora.py``, streaming status.

    This is a generator: every ``yield`` is a human-readable status string
    meant to be shown in the UI log box.

    Args:
        dataset_input: One uploaded file or a list of them (ZIP or images).
            Each item may be a path or an object with a ``name`` attribute.
        input_type: ``"Upload de ZIP"`` to extract the first uploaded file as
            an archive; any other value treats the uploads as loose images.
        model_name: Base model identifier forwarded as ``--model_name``.
        lora_rank: Forwarded as ``--lora_rank``.
        learning_rate: Forwarded as ``--learning_rate``.
        num_epochs: Forwarded as ``--num_epochs``.
        hub_token: Optional token. When non-empty it is exposed to the
            training subprocess as ``HF_TOKEN`` and ``--push_to_hub`` is added.
        concept_name: Concept/trigger name; spaces are replaced by ``_``.
        description: Base caption text; the concept name is appended to it.

    Yields:
        str: Progress, warning, success or error messages.
    """
    import sys  # local import: used only to locate the running interpreter

    if not dataset_input:
        yield "❌ Por favor, envie um ZIP ou selecione imagens."
        return
    if not concept_name.strip():
        yield "❌ Por favor, defina um nome para o conceito (ex: brenda)."
        return
    if not description.strip():
        yield "❌ Por favor, adicione uma descrição base."
        return

    concept_name = concept_name.strip().replace(" ", "_")
    full_description = f"{description.strip()}, {concept_name}"

    yield f"🏷️ Treinando: '{concept_name}'"

    dataset_dir = "processed_data"
    os.makedirs(dataset_dir, exist_ok=True)

    # Remove whatever a previous run left behind so old images/captions do
    # not leak into this training set.
    for error in _clear_directory(dataset_dir):
        yield f"⚠️ Erro ao limpar: {error}"

    # --- STEP 1: ingest the uploaded data ---
    if input_type == "Upload de ZIP":
        first_upload = dataset_input[0] if isinstance(dataset_input, list) else dataset_input
        # Prefer the on-disk path; fall back to the object itself because
        # zipfile also accepts file-like objects.
        zip_source = _uploaded_path(first_upload) or first_upload

        if not zipfile.is_zipfile(zip_source):
            yield "❌ Arquivo não é um ZIP válido."
            return

        yield "📦 Descompactando..."
        with zipfile.ZipFile(zip_source, 'r') as archive:
            member_count = len(archive.namelist())
            archive.extractall(dataset_dir)
        yield f"✅ ZIP extraído! {member_count} arquivos."

    else:
        image_files = dataset_input if isinstance(dataset_input, list) else [dataset_input]
        yield f"🖼️ Copiando {len(image_files)} imagens..."

        copied = 0
        for uploaded_file in image_files:
            src = _uploaded_path(uploaded_file)
            if not src:
                # Unknown upload representation: nothing we can copy.
                continue
            shutil.copy(src, os.path.join(dataset_dir, os.path.basename(src)))
            copied += 1

        # Report what was actually copied, not what was submitted.
        yield f"✅ {copied} imagens copiadas."

    # --- STEP 2: write caption files ---
    # NOTE(review): only the top level of dataset_dir is scanned, so images
    # inside a sub-folder of the ZIP are not detected — confirm whether
    # train_lora.py expects a flat directory before changing this.
    exts = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')
    images = [f for f in os.listdir(dataset_dir) if f.lower().endswith(exts)]

    if not images:
        yield "❌ Nenhuma imagem encontrada!"
        return

    yield f"📝 Aplicando legenda: '{full_description}'"

    for img in images:
        txt = os.path.join(dataset_dir, os.path.splitext(img)[0] + ".txt")
        # Keep captions that shipped with the dataset; only fill in gaps.
        if not os.path.exists(txt):
            with open(txt, "w", encoding="utf-8") as f:
                f.write(full_description)

    yield "🔍 Legendas prontas!"

    # --- STEP 3: training ---
    output_dir = "lora-output"
    os.makedirs(output_dir, exist_ok=True)

    cmd = [
        # Run the trainer with the same interpreter (and virtualenv) as the app.
        sys.executable or "python", "train_lora.py",
        "--dataset_dir", dataset_dir,
        "--model_name", model_name,
        "--lora_rank", str(lora_rank),
        "--learning_rate", str(learning_rate),
        "--num_epochs", str(num_epochs),
        "--batch_size", "1",
        "--output_dir", output_dir
    ]

    # Hand the token to the child only. Writing it into os.environ would keep
    # one user's secret in the server process for every later request.
    child_env = os.environ.copy()
    if hub_token:
        child_env["HF_TOKEN"] = hub_token
        cmd += ["--push_to_hub", "--hub_model_id", f"{concept_name}-lora"]

    yield "🔥 Iniciando treinamento..."

    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            encoding='utf-8',
            errors='replace',  # undecodable bytes must not abort log streaming
            env=child_env,
        )

        # Only the last 1000 characters are ever reported, so keep just that
        # tail instead of accumulating the whole training log in memory.
        log_tail = ""
        for line in process.stdout:
            log_tail = (log_tail + line)[-1000:]
            lowered = line.lower()
            if "loss" in lowered or "epoch" in lowered:
                yield f"📊 {line.strip()}"

        process.stdout.close()
        process.wait()

        if process.returncode == 0:
            yield f"""
🎉 SUCESSO!

🔹 Use no prompt: `photo of {concept_name} in the forest`
🔹 Modelo salvo em: `{output_dir}`
{'🔹 Publicado no Hub!' if hub_token else ''}
            """
        else:
            yield f"❌ Falha no treinamento. Código: {process.returncode}\nLogs:\n{log_tail}"

    except Exception as e:
        # UI boundary: surface any launch/streaming failure as a message.
        yield f"💥 Erro: {str(e)}"

# --- Gradio UI ---
# Builds the Blocks layout top-to-bottom; the order of the statements below
# is the order in which the components are rendered.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Treinador de LoRA - Hugging Face")
    gr.Markdown("Treine personagens, estilos ou objetos personalizados.")

    # Selects how train_lora_interface interprets the uploaded files: the
    # first choice is compared there by its exact label text.
    with gr.Row():
        input_type = gr.Radio(
            ["Upload de ZIP", "Selecionar várias imagens"],
            label="Tipo de Entrada",
            value="Upload de ZIP"
        )

    # A single multi-file upload widget serves both input modes.
    with gr.Row():
        dataset_input = gr.File(
            label="📤 Envie seu ZIP ou imagens",
            file_types=[".zip", ".jpg", ".jpeg", ".png", ".bmp", ".webp"],
            file_count="multiple"
        )

    # Concept name and base description; the handler rejects empty values
    # for either field.
    gr.Markdown("### 🔖 Identidade do Personagem")
    with gr.Row():
        concept_name = gr.Textbox(
            label="Nome do Conceito (ex: brenda)",
            placeholder="Ex: brenda, cyborg_x",
            value=""
        )
    with gr.Row():
        description = gr.Textbox(
            label="Descrição Base (ex: woman, curly hair)",
            placeholder="Ex: young black woman, realistic style",
            lines=2
        )

    # Hyper-parameters forwarded to the training script as CLI arguments.
    gr.Markdown("### ⚙️ Configurações")
    with gr.Row():
        model_name = gr.Dropdown(
            ["runwayml/stable-diffusion-v1-5"],
            value="runwayml/stable-diffusion-v1-5",
            label="Modelo Base"
        )
        lora_rank = gr.Slider(4, 64, value=4, step=4, label="LoRA Rank")
        learning_rate = gr.Number(value=1e-4, label="Taxa de Aprendizado")
        num_epochs = gr.Slider(1, 30, value=10, step=1, label="Épocas")

    # Optional: when filled in, the handler also asks the trainer to push
    # the result to the Hub.
    hub_token = gr.Textbox(label="🔐 Token do HF (opcional)", type="password")

    btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
    output = gr.Textbox(label="📦 Logs", lines=12)

    # The inputs list must stay in the same order as the parameters of
    # train_lora_interface, since values are passed positionally.
    btn.click(
        train_lora_interface,
        inputs=[
            dataset_input, input_type, model_name, lora_rank,
            learning_rate, num_epochs, hub_token, concept_name, description
        ],
        outputs=output
    )

# Enable the request queue at import time (runs even when this module is
# imported rather than executed).
# NOTE(review): the handler is a generator; presumably the queue is what lets
# its yielded messages stream to the UI — confirm for the installed Gradio version.
demo.queue()
if __name__ == "__main__":
    demo.launch()