# Spaces status: Runtime error
import os
import subprocess
import sys

import gradio as gr
from PIL import Image
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
# ===== 1. Initialize BLIP-2 for Auto-Captioning =====
_BLIP_MODEL_ID = "Salesforce/blip2-opt-2.7b"


def load_blip_model():
    """Load the BLIP-2 processor and captioning model.

    Returns:
        A ``(processor, model, device)`` tuple. ``device`` is ``"cuda"`` when
        ``torch.cuda.is_available()`` is true and ``"cpu"`` otherwise.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision is only requested on the GPU. On CPU the weights stay in
    # float32, since half-precision CPU kernels are incomplete or very slow in
    # many torch builds.
    dtype = torch.float16 if device == "cuda" else torch.float32
    processor = Blip2Processor.from_pretrained(_BLIP_MODEL_ID)
    model = Blip2ForConditionalGeneration.from_pretrained(
        _BLIP_MODEL_ID,
        torch_dtype=dtype,
    ).to(device)
    return processor, model, device


# Loaded once at import time so every caption request reuses the same weights.
processor, model, device = load_blip_model()


def generate_caption(image_path, trigger_word):
    """Caption one image with BLIP-2 and prefix it with the trigger word.

    Args:
        image_path: Path of the image file to describe.
        trigger_word: Token placed in square brackets at the start of the caption.

    Returns:
        A string of the form ``"a photo of [<trigger_word>], <caption>"``.
    """
    # Close the file handle promptly and normalise the mode (RGBA, palette,
    # grayscale) so the processor always receives a three-channel image.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    # Cast the inputs to whatever dtype the model was actually loaded with
    # instead of assuming float16.
    inputs = processor(image, return_tensors="pt").to(device, model.dtype)
    generated_ids = model.generate(**inputs, max_new_tokens=50)
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return f"a photo of [{trigger_word}], {caption}"
# ===== 2. Install Kohya_SS Manually =====
def _run_setup_step(args, cwd=None):
    """Run one installation command and report whether it succeeded.

    Failures are printed rather than raised so that a broken install step does
    not prevent the app from starting (the original ``os.system`` calls never
    raised either); the difference is that failures are now visible in the log.

    Args:
        args: Command and arguments as a list (no shell involved).
        cwd: Optional working directory for the command.

    Returns:
        True when the command exited with status 0, False otherwise.
    """
    try:
        result = subprocess.run(args, cwd=cwd, check=False)
    except OSError as exc:
        # E.g. the executable (git) is not installed.
        print(f"⚠️ Could not run {' '.join(args)}: {exc}")
        return False
    if result.returncode != 0:
        print(f"⚠️ Command failed with exit code {result.returncode}: {' '.join(args)}")
        return False
    return True


if not os.path.exists("kohya_ss"):
    print("⬇️ Installing Kohya_SS...")
    # Only attempt the pip steps when the clone produced a checkout to work in.
    if _run_setup_step(["git", "clone", "https://github.com/bmaltais/kohya_ss"]):
        # "python -m pip" targets the interpreter running this app rather than
        # whichever "pip" happens to be first on PATH.
        _run_setup_step(
            [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
            cwd="kohya_ss",
        )
        _run_setup_step([sys.executable, "-m", "pip", "install", "."], cwd="kohya_ss")
# ===== 3. Training Function =====
def train_lora(images, trigger_word, progress=gr.Progress()):
    """Caption the uploaded images and train a LoRA on them with Kohya_SS.

    Args:
        images: Uploads from the ``gr.Files`` input. Each entry may be a file
            path, an object exposing its path as ``.name``, or a PIL image.
        trigger_word: Token inserted into every generated caption.
        progress: Gradio progress tracker.

    Returns:
        Path of the trained LoRA file, ``"output/lora.safetensors"``.

    Raises:
        gr.Error: If no images were uploaded or training produced no output file.
        subprocess.CalledProcessError: If the training script exits non-zero.
    """
    if not images:
        raise gr.Error("Please upload at least one image before training.")

    progress(0.1, desc="Preparing data...")

    # Save images + auto-caption
    os.makedirs("train", exist_ok=True)
    # Drop files left over from a previous run so they are not trained on again.
    for stale in os.listdir("train"):
        if stale.startswith("img_"):
            os.remove(os.path.join("train", stale))

    for i, upload in enumerate(progress.tqdm(images, desc="Processing images")):
        img_path = f"train/img_{i}.jpg"
        if isinstance(upload, Image.Image):
            picture = upload.convert("RGB")
        else:
            # gr.Files hands over paths (or temp-file wrappers with .name),
            # not decoded images, so the file has to be opened first.
            if isinstance(upload, (str, os.PathLike)):
                source_path = upload
            else:
                source_path = upload.name
            with Image.open(source_path) as opened:
                # JPEG cannot store alpha/palette modes; normalise to RGB.
                picture = opened.convert("RGB")
        picture.save(img_path)

        caption = generate_caption(img_path, trigger_word)
        with open(f"train/img_{i}.txt", "w", encoding="utf-8") as f:
            f.write(caption)

    # Train LoRA (optimized for HF Spaces T4 GPU)
    # NOTE(review): the base checkpoint is SDXL; sd-scripts ships a separate
    # sdxl_train_network.py for SDXL models, and the script location inside the
    # kohya_ss checkout depends on its version - confirm both.
    # NOTE(review): sd-scripts usually expects images in "<repeats>_<name>"
    # sub-folders of train_data_dir - confirm the flat "train/" layout works.
    cmd = [
        sys.executable,  # same interpreter/environment as the app itself
        "kohya_ss/train_network.py",
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--train_data_dir=train",
        "--output_dir=output",
        # Name the result explicitly so it matches the path returned below.
        "--output_name=lora",
        "--save_model_as=safetensors",
        # Captions are written as .txt files next to the images above.
        "--caption_extension=.txt",
        "--resolution=512",
        "--network_module=networks.lora",
        "--network_dim=32",
        # "--lr" is only an ambiguous prefix of several options.
        "--learning_rate=1e-4",
        "--max_train_steps=800",
        "--mixed_precision=fp16",
        "--save_precision=fp16",
        "--optimizer_type=AdamW8bit",
        "--xformers",
    ]
    progress(0.8, desc="Training LoRA...")
    subprocess.run(cmd, check=True)

    lora_path = "output/lora.safetensors"
    if not os.path.exists(lora_path):
        raise gr.Error("Training finished but no LoRA file was produced.")
    return lora_path
# ===== 4. Gradio UI =====
with gr.Blocks(title="1-Click LoRA Trainer") as demo:
    gr.Markdown("""
    ## 🎨 Weights.gg-Style LoRA Trainer
    Upload 30 images + set a trigger word to train a custom LoRA.
    """)

    with gr.Row():
        # Left column: training inputs.
        with gr.Column():
            images = gr.Files(
                label="Upload Character Images (30 max)",
                file_types=["image"],
                interactive=True
            )
            trigger = gr.Textbox(
                label="Trigger Word",
                placeholder="E.g., 'my_char' (used as [my_char] in prompts)"
            )
            train_btn = gr.Button("🚀 Train LoRA", variant="primary")

        # Right column: results.
        with gr.Column():
            output = gr.File(label="Download LoRA")
            # Not wired to any event in this file; shown as an empty preview area.
            gallery = gr.Gallery(label="Training Preview")

    # The click handler returns the path of the trained LoRA for download.
    train_btn.click(
        train_lora,
        inputs=[images, trigger],
        outputs=output,
        api_name="train"
    )

if __name__ == "__main__":
    # train_lora reports progress via gr.Progress and runs for a long time;
    # enable the queue explicitly so progress updates and long-running
    # requests work regardless of the Gradio version's default.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)