# Lora-trainer-all / lora_training.py
# Allex21's picture
# Upload 12 files
# 7c8a29e verified
import os
import uuid
import json
import threading
import subprocess
import time
from datetime import datetime
from flask import Blueprint, request, jsonify, send_file
from werkzeug.utils import secure_filename
import shutil
from src.lora_trainer import create_lora_trainer, validate_training_config
# Blueprint that groups every LoRA-training HTTP route defined in this module.
lora_bp = Blueprint(name='lora', import_name=__name__)
# Settings: storage locations (all under /tmp), one subfolder per training id.
UPLOAD_FOLDER = '/tmp/lora_uploads'      # raw images as uploaded by the client
TRAINING_FOLDER = '/tmp/lora_training'   # working copies handed to the trainer
RESULTS_FOLDER = '/tmp/lora_results'     # trainer output plus README/metadata
# Accepted image extensions, compared against the lowercased file suffix.
ALLOWED_EXTENSIONS = set('png jpg jpeg webp bmp'.split())
# Training status registry: training_id -> status dict. It lives only in this
# process's memory, so it is empty again after a restart.
training_status = dict()
def allowed_file(filename, allowed_extensions=None):
    """Return True if *filename* ends in an allowed image extension.

    The extension is the text after the last '.', compared case-insensitively.

    Args:
        filename: Name to check. ``None`` or an empty string is reported as
            not allowed instead of raising ``TypeError``.
        allowed_extensions: Optional collection of lowercase extensions
            (without the dot). Defaults to the module-level ALLOWED_EXTENSIONS.

    Returns:
        bool: whether the extension is in the allowed set.
    """
    if not filename or '.' not in filename:
        return False
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    return filename.rsplit('.', 1)[1].lower() in allowed_extensions
def ensure_directories():
    """Create the upload, working and results folders if they do not exist yet."""
    required_dirs = (UPLOAD_FOLDER, TRAINING_FOLDER, RESULTS_FOLDER)
    for directory in required_dirs:
        # exist_ok=True makes repeated calls (one per request) harmless.
        os.makedirs(directory, exist_ok=True)
def _copy_allowed_images(source_dir, target_dir):
    """Copy every file in source_dir with an allowed image extension into target_dir."""
    for filename in os.listdir(source_dir):
        if allowed_file(filename):
            shutil.copy2(os.path.join(source_dir, filename),
                         os.path.join(target_dir, filename))


def _build_download_links(training_id, output_dir):
    """Return ``{'name', 'url'}`` entries for each regular file in output_dir, sorted by name."""
    return [
        {'name': filename, 'url': f'/api/download/{training_id}/{filename}'}
        for filename in sorted(os.listdir(output_dir))
        if os.path.isfile(os.path.join(output_dir, filename))
    ]


def run_training_process(training_id, config):
    """Run one LoRA training job; used as the target of a worker thread.

    Copies the uploaded images into a per-job working folder, validates the
    config, runs the trainer and writes README/metadata next to the results.
    Progress, messages and logs are written into ``training_status[training_id]``,
    which must already exist. Failures are recorded there as
    ``status='error'`` and are not re-raised.

    Args:
        training_id: Id of the job; also the name of its upload, training and
            results subfolders.
        config: Training settings. ``images_dir`` and ``output_dir`` are
            filled in here before validation.
    """
    # Keep a direct reference to this job's status dict. If the entry is
    # removed from training_status while we run (delete endpoint), updates
    # land in the detached dict instead of raising KeyError. Before this
    # change, the KeyError was also raised inside the error handler and
    # crashed the thread.
    status = training_status[training_id]
    try:
        status['status'] = 'preparing'
        status['progress'] = 10
        status['message'] = 'Preparando ambiente de treinamento...'

        # Working copy of the images: <TRAINING_FOLDER>/<id>/images
        training_dir = os.path.join(TRAINING_FOLDER, training_id)
        images_dir = os.path.join(training_dir, 'images')
        os.makedirs(images_dir, exist_ok=True)
        _copy_allowed_images(os.path.join(UPLOAD_FOLDER, training_id), images_dir)

        status['progress'] = 20
        status['message'] = 'Preparando dataset...'

        output_dir = os.path.join(RESULTS_FOLDER, training_id)
        os.makedirs(output_dir, exist_ok=True)
        config['images_dir'] = images_dir
        config['output_dir'] = output_dir

        is_valid, validation_message = validate_training_config(config)
        if not is_valid:
            raise ValueError(f"Configuração inválida: {validation_message}")

        status['progress'] = 30
        status['message'] = 'Iniciando treinamento LoRA...'

        def progress_callback(progress, message):
            # Trainer progress is clamped into 30..90 so the bar never drops
            # below the preparation steps or hits 100 before post-processing.
            status['progress'] = max(30, min(90, progress))
            status['message'] = message
            status['logs'] = status.get('logs', []) + [message]

        trainer = create_lora_trainer(config)
        trainer.train(progress_callback=progress_callback)
        status['logs'] = trainer.training_logs

        # Produce every artifact BEFORE announcing completion. Pollers that
        # see completed=True must also find README/metadata and
        # download_links, and a failure here must not leave completed=True
        # next to status='error'.
        create_additional_files(training_id, config, output_dir)
        status['download_links'] = _build_download_links(training_id, output_dir)
        status['trigger_word'] = config['trigger_word']

        status['status'] = 'completed'
        status['progress'] = 100
        status['message'] = 'Treinamento concluído com sucesso!'
        status['completed'] = True
    except Exception as e:
        # Worker-thread boundary: record the failure for the status endpoint.
        status['status'] = 'error'
        status['error'] = str(e)
        status['message'] = f'Erro durante treinamento: {str(e)}'
def create_additional_files(training_id, config, output_dir):
    """Write README.md and metadata.json describing a trained LoRA.

    Both files are written into *output_dir* (overwriting existing ones) as
    UTF-8.

    Args:
        training_id: Id of the training run; not used in the generated files.
        config: Training settings. Reads character_name, trigger_word,
            resolution, rank, learning_rate and epochs. The whole dict is
            also embedded as ``training_config`` in metadata.json, so it must
            be JSON-serializable.
        output_dir: Existing directory that receives the two files.
    """
    # Sample the clock once so the README date and metadata created_at agree
    # (two separate now() calls could straddle a second boundary).
    created_at = datetime.now()

    # README with detailed usage instructions (user-facing text in Portuguese).
    readme_content = f'''# LoRA: {config['character_name']}
## Informações do Treinamento
- **Personagem**: {config['character_name']}
- **Trigger Word**: {config['trigger_word']}
- **Resolução**: {config['resolution']}x{config['resolution']}
- **Rank**: {config['rank']}
- **Learning Rate**: {config['learning_rate']}
- **Épocas**: {config['epochs']}
- **Data de Treinamento**: {created_at.strftime('%Y-%m-%d %H:%M:%S')}
## Como Usar
### ComfyUI
1. Coloque o arquivo `pytorch_lora_weights.safetensors` na pasta `models/loras/` do ComfyUI
2. No workflow, adicione um nó "Load LoRA"
3. Selecione o arquivo LoRA
4. Use a trigger word "{config['trigger_word']}" em seus prompts
5. Ajuste o peso entre 0.7 e 1.0
### Automatic1111
1. Coloque o arquivo `pytorch_lora_weights.safetensors` na pasta `models/Lora/`
2. Use a sintaxe `<lora:pytorch_lora_weights:0.8>` no prompt
3. Inclua a trigger word "{config['trigger_word']}" no prompt
### SeaArt
1. Faça upload do arquivo LoRA na plataforma
2. Selecione o LoRA em suas gerações
3. Use a trigger word "{config['trigger_word']}" no prompt
## Exemplos de Prompts
- "{config['trigger_word']}, portrait, high quality"
- "{config['trigger_word']}, full body, standing, detailed"
- "{config['trigger_word']}, close-up, beautiful lighting"
- "{config['trigger_word']}, anime style, colorful"
- "{config['trigger_word']}, realistic, photographic"
## Dicas de Uso
- Use peso entre 0.7-1.0 para melhores resultados
- Combine com outros LoRAs para estilos específicos
- Experimente diferentes CFG scales (7-12)
- Para resultados mais consistentes, use a trigger word no início do prompt
- Ajuste o peso do LoRA conforme necessário para cada geração
## Compatibilidade
Este LoRA é compatível com:
- Stable Diffusion 1.5
- ComfyUI
- Automatic1111 WebUI
- SeaArt
- Fooocus
- InvokeAI
- E outras ferramentas que suportam LoRA
## Suporte
Para dúvidas ou problemas, consulte a documentação da ferramenta utilizada.
'''
    readme_file = os.path.join(output_dir, 'README.md')
    with open(readme_file, 'w', encoding='utf-8') as f:
        f.write(readme_content)

    # Machine-readable metadata about the model and how to use it.
    metadata = {
        "model_name": config['character_name'],
        "trigger_word": config['trigger_word'],
        "base_model": "runwayml/stable-diffusion-v1-5",
        "training_config": config,
        "created_at": created_at.isoformat(),
        "version": "1.0",
        "type": "character_lora",
        "tags": ["character", "lora", "consistent", config['character_name']],
        "description": f"LoRA treinado para o personagem {config['character_name']} usando a trigger word '{config['trigger_word']}'",
        "usage_instructions": {
            "trigger_word": config['trigger_word'],
            "recommended_weight": "0.7-1.0",
            "compatible_models": ["SD1.5"],
            "example_prompts": [
                f"{config['trigger_word']}, portrait, high quality",
                f"{config['trigger_word']}, full body, detailed",
                f"{config['trigger_word']}, close-up, beautiful lighting"
            ]
        }
    }
    metadata_file = os.path.join(output_dir, 'metadata.json')
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)
def _save_uploaded_images(files, upload_dir):
    """Save allowed image uploads into upload_dir and return the stored file names.

    Names go through secure_filename. An upload whose sanitized name no longer
    has an allowed extension is skipped, because run_training_process would
    ignore it anyway. A name that collides (case-insensitively) with an earlier
    upload gets a numeric suffix instead of silently overwriting that file.
    """
    saved = []
    taken = set()
    for upload in files:
        # FileStorage is falsy when it carries no filename.
        if not upload or not allowed_file(upload.filename):
            continue
        filename = secure_filename(upload.filename)
        if not allowed_file(filename):
            continue
        stem, ext = os.path.splitext(filename)
        candidate = filename
        counter = 1
        while candidate.lower() in taken:
            candidate = f'{stem}_{counter}{ext}'
            counter += 1
        upload.save(os.path.join(upload_dir, candidate))
        taken.add(candidate.lower())
        saved.append(candidate)
    return saved


@lora_bp.route('/train', methods=['POST'])
def start_training():
    """Accept uploaded images plus form settings and start a LoRA training.

    Expects a multipart form with at least five image files under ``images``.
    ``character_name`` and ``trigger_word`` are required; ``resolution``,
    ``learning_rate``, ``rank``, ``epochs`` and ``description`` are optional.
    The training itself runs in a daemon thread (run_training_process); its
    progress can be polled with the returned ``training_id``.

    Returns:
        JSON with ``success``/``training_id``/``message`` on success, a 400
        response for invalid input, or a 500 response for unexpected errors.
    """
    upload_dir = None
    training_id = None
    started = False
    try:
        ensure_directories()
        # Unique id for this training
        training_id = str(uuid.uuid4())

        if 'images' not in request.files:
            return jsonify({'success': False, 'message': 'Nenhuma imagem foi enviada'}), 400
        files = request.files.getlist('images')
        if len(files) < 5:
            return jsonify({'success': False, 'message': 'Mínimo de 5 imagens necessárias'}), 400

        upload_dir = os.path.join(UPLOAD_FOLDER, training_id)
        os.makedirs(upload_dir, exist_ok=True)
        saved_files = _save_uploaded_images(files, upload_dir)
        if len(saved_files) < 5:
            return jsonify({'success': False, 'message': 'Pelo menos 5 imagens válidas são necessárias'}), 400

        # Form settings; numeric values are kept as the submitted strings.
        config = {
            'character_name': request.form.get('character_name', '').strip(),
            'trigger_word': request.form.get('trigger_word', '').strip(),
            'resolution': request.form.get('resolution', '512'),
            'learning_rate': request.form.get('learning_rate', '1e-4'),
            'rank': request.form.get('rank', '16'),
            'epochs': request.form.get('epochs', '20'),
            'description': request.form.get('description', '').strip(),
            'images': saved_files,
            'training_id': training_id,
            'created_at': datetime.now().isoformat()
        }
        if not config['character_name']:
            return jsonify({'success': False, 'message': 'Nome do personagem é obrigatório'}), 400
        if not config['trigger_word']:
            return jsonify({'success': False, 'message': 'Trigger word é obrigatória'}), 400

        training_status[training_id] = {
            'status': 'starting',
            'progress': 0,
            'message': 'Iniciando treinamento...',
            'logs': [],
            'completed': False,
            'error': None,
            'config': config
        }

        training_thread = threading.Thread(
            target=run_training_process,
            args=(training_id, config)
        )
        training_thread.daemon = True
        training_thread.start()
        started = True

        return jsonify({
            'success': True,
            'training_id': training_id,
            'message': 'Treinamento iniciado com sucesso'
        })
    except Exception as e:
        return jsonify({'success': False, 'message': f'Erro interno: {str(e)}'}), 500
    finally:
        if not started:
            # The request was rejected, or failed before the worker took over.
            # Remove the half-created upload folder and status entry so they
            # do not pile up on disk / in memory.
            if upload_dir is not None:
                shutil.rmtree(upload_dir, ignore_errors=True)
            if training_id is not None:
                training_status.pop(training_id, None)
@lora_bp.route('/training-status/<training_id>', methods=['GET'])
def get_training_status(training_id):
    """Return the stored status dict of one training as JSON, or a 404 if the id is unknown."""
    entry = training_status.get(training_id)
    if entry is None:
        return jsonify({'error': 'Treinamento não encontrado'}), 404
    return jsonify(entry)
@lora_bp.route('/download/<training_id>/<filename>', methods=['GET'])
def download_file(training_id, filename):
    """Send one result file of a training as an attachment.

    Args:
        training_id: Training whose results folder is read (taken from the URL).
        filename: Name of the file inside ``RESULTS_FOLDER/<training_id>``.

    Returns:
        The file response. Returns 404 when the path does not resolve to a
        regular file inside RESULTS_FOLDER, and 500 on unexpected errors.
    """
    try:
        results_root = os.path.realpath(RESULTS_FOLDER)
        file_path = os.path.realpath(os.path.join(results_root, training_id, filename))
        # Both path parts come straight from the URL. After resolving '..'
        # and symlinks, refuse anything outside RESULTS_FOLDER so a request
        # like /download/../<name> cannot read arbitrary files.
        if os.path.commonpath([results_root, file_path]) != results_root:
            return jsonify({'error': 'Arquivo não encontrado'}), 404
        # Directories (e.g. filename '..') cannot be sent; treat them as missing.
        if not os.path.isfile(file_path):
            return jsonify({'error': 'Arquivo não encontrado'}), 404
        return send_file(file_path, as_attachment=True, download_name=filename)
    except Exception as e:
        return jsonify({'error': f'Erro ao baixar arquivo: {str(e)}'}), 500
@lora_bp.route('/trainings', methods=['GET'])
def list_trainings():
    """Return a short summary (id, name, status, progress, creation time) of every known training."""
    summaries = [
        {
            'id': tid,
            'character_name': entry.get('config', {}).get('character_name', 'Desconhecido'),
            'status': entry.get('status', 'unknown'),
            'progress': entry.get('progress', 0),
            'created_at': entry.get('config', {}).get('created_at', ''),
        }
        for tid, entry in training_status.items()
    ]
    return jsonify({'trainings': summaries})
@lora_bp.route('/delete-training/<training_id>', methods=['DELETE'])
def delete_training(training_id):
    """Remove a training's status entry and its upload/training/results folders.

    NOTE(review): a worker thread that is still running for this id is not
    stopped by this endpoint.

    Args:
        training_id: Canonical UUID string, the form generated by start_training.

    Returns:
        JSON success message, 400 for a malformed id, or 500 if removal fails.
    """
    # Only accept the canonical UUID form start_training generates. The id
    # comes from the URL and is joined onto the storage folders below. An id
    # such as '..' would otherwise make rmtree delete their parent (/tmp).
    try:
        is_valid_id = str(uuid.UUID(training_id)) == training_id
    except ValueError:
        is_valid_id = False
    if not is_valid_id:
        return jsonify({'success': False, 'message': 'ID de treinamento inválido'}), 400

    try:
        training_status.pop(training_id, None)
        for base_dir in (UPLOAD_FOLDER, TRAINING_FOLDER, RESULTS_FOLDER):
            training_dir = os.path.join(base_dir, training_id)
            if os.path.isdir(training_dir):
                shutil.rmtree(training_dir)
        return jsonify({'success': True, 'message': 'Treinamento removido com sucesso'})
    except Exception as e:
        return jsonify({'success': False, 'message': f'Erro ao remover treinamento: {str(e)}'}), 500