File size: 3,444 Bytes
d41cf19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import subprocess
import gradio as gr
from PIL import Image
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# ===== 1. Initialize BLIP-2 for Auto-Captioning =====
def load_blip_model():
    """Load the BLIP-2 captioning processor and model.

    The checkpoint is ``Salesforce/blip2-opt-2.7b``. The weights are requested
    as float16 on both CUDA and CPU, and the model is moved to CUDA when it is
    available, otherwise it stays on the CPU.

    Returns:
        A ``(processor, model, device)`` tuple, where ``device`` is the
        string ``"cuda"`` or ``"cpu"``.
    """
    checkpoint = "Salesforce/blip2-opt-2.7b"

    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    blip_processor = Blip2Processor.from_pretrained(checkpoint)
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,
    )
    blip_model = blip_model.to(target_device)

    return blip_processor, blip_model, target_device

# Loaded once at import time; generate_caption() reads these module-level names.
processor, model, device = load_blip_model()

def generate_caption(image_path, trigger_word):
    """Caption one image with BLIP-2 and prefix it with the trigger word.

    Args:
        image_path: Path of the image file to caption.
        trigger_word: Token inserted in square brackets at the start of the
            caption.

    Returns:
        A string of the form ``"a photo of [<trigger_word>], <caption>"``.
    """
    # Use a context manager so the file handle is released right away.
    # convert() returns a fully loaded copy, so the image stays usable after
    # the file is closed, and non-RGB modes are normalised for the processor.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Follow the dtype the model was actually loaded with instead of
    # hard-coding float16, so inputs and weights can never disagree.
    inputs = processor(image, return_tensors="pt").to(device, model.dtype)
    generated_ids = model.generate(**inputs, max_new_tokens=50)
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return f"a photo of [{trigger_word}], {caption}"

# ===== 2. Install Kohya_SS Manually =====
if not os.path.exists("kohya_ss"):
    print("⬇️ Installing Kohya_SS...")
    # Each step is (argv, working directory). Argument lists avoid the shell,
    # and ``cwd`` replaces the former ``cd kohya_ss && ...`` prefix.
    _install_steps = (
        (["git", "clone", "https://github.com/bmaltais/kohya_ss"], None),
        (["pip", "install", "-r", "requirements.txt"], "kohya_ss"),
        (["pip", "install", "."], "kohya_ss"),
    )
    for _argv, _cwd in _install_steps:
        # Installation stays best-effort (the app still starts on failure),
        # but failures are now reported instead of being silently discarded.
        try:
            _returncode = subprocess.run(_argv, cwd=_cwd, check=False).returncode
        except OSError as _err:
            # Raised when the executable or the working directory is missing.
            print(f"Install step could not start ({_err}): {' '.join(_argv)}")
            continue
        if _returncode != 0:
            print(f"Install step failed (exit code {_returncode}): {' '.join(_argv)}")

# ===== 3. Training Function =====
def _load_uploaded_image(upload):
    """Return one uploaded item as an RGB PIL image.

    The upload component hands the callback file paths or temp-file wrappers
    exposing ``.name`` (depending on the Gradio version), not PIL images, so
    the file has to be opened here. PIL images are accepted as well.
    """
    if isinstance(upload, Image.Image):
        return upload.convert("RGB")

    if isinstance(upload, (str, os.PathLike)):
        path = upload
    else:
        path = getattr(upload, "name", None)
    if not path:
        raise gr.Error("Unsupported upload: expected an image file.")

    try:
        # convert() loads the pixel data, so the result outlives the file.
        with Image.open(path) as raw_image:
            return raw_image.convert("RGB")
    except OSError as err:
        file_name = os.path.basename(os.fspath(path))
        raise gr.Error(f"Could not read image: {file_name}") from err


def train_lora(images, trigger_word, progress=gr.Progress()):
    """Caption the uploaded images and train a LoRA with kohya_ss.

    Each upload is written to ``train/img_<i>.jpg`` with a matching
    ``train/img_<i>.txt`` caption, then ``kohya_ss/train_network.py`` is run
    as a subprocess.

    Args:
        images: Uploaded files from the ``gr.Files`` component.
        trigger_word: Token embedded in every generated caption.
        progress: Gradio progress tracker used to report status to the UI.

    Returns:
        Path of the trained LoRA file, ``"output/lora.safetensors"``.

    Raises:
        gr.Error: If no images or no trigger word were supplied, an upload
            cannot be read, training exits with a non-zero status, or the
            expected output file is missing afterwards.
    """
    if not images:
        raise gr.Error("Please upload at least one image.")
    trigger_word = (trigger_word or "").strip()
    if not trigger_word:
        raise gr.Error("Please enter a trigger word.")

    progress(0.1, desc="Preparing data...")

    os.makedirs("train", exist_ok=True)
    # Remove images/captions left by an earlier run so they do not leak into
    # this dataset when fewer files are uploaded this time.
    with os.scandir("train") as entries:
        stale_paths = [
            entry.path
            for entry in entries
            if entry.is_file() and entry.name.startswith("img_")
        ]
    for stale_path in stale_paths:
        os.remove(stale_path)

    # Save images + auto-caption
    for i, upload in enumerate(progress.tqdm(images, desc="Processing images")):
        img_path = f"train/img_{i}.jpg"
        _load_uploaded_image(upload).save(img_path)
        caption = generate_caption(img_path, trigger_word)
        with open(f"train/img_{i}.txt", "w", encoding="utf-8") as f:
            f.write(caption)

    # Train LoRA (optimized for HF Spaces T4 GPU)
    # NOTE(review): sd-scripts' folder-based datasets appear to expect images
    # inside "<repeats>_<name>" subfolders of train_data_dir, and SDXL
    # checkpoints are normally trained with sdxl_train_network.py - confirm
    # both against the cloned kohya_ss revision.
    cmd = [
        "python",
        "kohya_ss/train_network.py",
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--train_data_dir=train",
        "--output_dir=output",
        # Makes the trainer write output/lora.safetensors, the path returned
        # below; without it the file is named after the trainer's default.
        "--output_name=lora",
        "--save_model_as=safetensors",
        # Captions are written as .txt files above.
        "--caption_extension=.txt",
        "--network_module=networks.lora",
        "--resolution=512",
        "--network_dim=32",
        # The trainer's flag is --learning_rate; "--lr" is not an option.
        "--learning_rate=1e-4",
        "--max_train_steps=800",
        "--mixed_precision=fp16",
        "--save_precision=fp16",
        "--optimizer_type=AdamW8bit",
        "--xformers",
    ]

    progress(0.8, desc="Training LoRA...")
    try:
        # An argument list with no shell avoids quoting/injection pitfalls.
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as err:
        raise gr.Error(
            f"Training failed (exit code {err.returncode}). "
            "See the server log for details."
        ) from err

    lora_path = "output/lora.safetensors"
    if not os.path.exists(lora_path):
        raise gr.Error("Training finished but no LoRA file was produced.")
    return lora_path

# ===== 4. Gradio UI =====
# Build the single-page UI. `demo` is the Blocks app launched at the bottom
# of the file.
with gr.Blocks(title="1-Click LoRA Trainer") as demo:
    # Page heading and short usage hint.
    gr.Markdown("""
    ## 🎨 Weights.gg-Style LoRA Trainer
    Upload 30 images + set a trigger word to train a custom LoRA.
    """)
    
    with gr.Row():
        # Left column: training inputs and the start button.
        with gr.Column():
            # The "30 max" limit is only stated in the label; nothing in the
            # visible code enforces it.
            images = gr.Files(
                label="Upload Character Images (30 max)",
                file_types=["image"],
                interactive=True
            )
            trigger = gr.Textbox(
                label="Trigger Word",
                placeholder="E.g., 'my_char' (used as [my_char] in prompts)"
            )
            train_btn = gr.Button("🚀 Train LoRA", variant="primary")
        
        # Right column: results.
        with gr.Column():
            # Receives the file path returned by train_lora.
            output = gr.File(label="Download LoRA")
            # NOTE(review): created but not connected to any event in the
            # visible code, so it is never populated here.
            gallery = gr.Gallery(label="Training Preview")
    
    # Clicking the button calls train_lora(images, trigger) and sends its
    # return value to `output`; the event is registered under the API name
    # "train".
    train_btn.click(
        train_lora,
        inputs=[images, trigger],
        outputs=output,
        api_name="train"
    )

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)