Lora-trainer / lora_projects.py
Allex21's picture
Upload 24 files
5bb2330 verified
from flask import Blueprint, request, jsonify
from src.extensions import db
from src.models.lora_project import LoRAProject, TrainingStatus
from datetime import datetime
# Blueprint that owns every route in this module (project CRUD, training
# control, and the model catalogue); presumably registered on the Flask app
# elsewhere.
lora_bp = Blueprint("lora_projects", import_name=__name__)
@lora_bp.route("/projects", methods=["GET"])
def get_projects():
    """List every LoRA project as a JSON array, most recently created first."""
    newest_first = LoRAProject.created_at.desc()
    rows = LoRAProject.query.order_by(newest_first).all()
    serialized = [row.to_dict() for row in rows]
    return jsonify(serialized)
@lora_bp.route("/projects", methods=["POST"])
def create_project():
    """Create and persist a new LoRA training project from a JSON body.

    Required keys: ``name`` and ``base_model`` (must be present and not null).
    Every other hyperparameter is optional and falls back to the default shown
    below. New projects always start in ``TrainingStatus.PENDING``.

    Returns:
        ``(json, 201)`` with the serialized project on success, or
        ``(json, 400)`` with a ``{"message": ...}`` body when the payload is not
        a JSON object or a required key is missing. Malformed JSON / wrong
        content type are handled by ``request.get_json()`` itself (the exact
        response depends on the Flask version).
    """
    data = request.get_json()
    # Older Flask returns None for non-JSON requests, and valid JSON may still
    # be a list/string/null; indexing such values would otherwise be a 500.
    if not isinstance(data, dict):
        return jsonify({"message": "Request body must be a JSON object"}), 400

    # Previously a missing key surfaced as an unhandled KeyError (HTTP 500).
    missing = [field for field in ("name", "base_model") if data.get(field) is None]
    if missing:
        return jsonify(
            {"message": f"Missing required field(s): {', '.join(missing)}"}
        ), 400

    new_project = LoRAProject(
        name=data["name"],
        description=data.get("description"),
        base_model=data["base_model"],
        rank=data.get("rank", 4),
        alpha=data.get("alpha", 32),
        dropout=data.get("dropout", 0.1),
        learning_rate=data.get("learning_rate", 1e-4),
        batch_size=data.get("batch_size", 1),
        num_epochs=data.get("num_epochs", 10),
        use_8bit_optimizer=data.get("use_8bit_optimizer", True),
        use_gradient_checkpointing=data.get("use_gradient_checkpointing", True),
        mixed_precision=data.get("mixed_precision", "fp16"),
        dataset_path=data.get("dataset_path"),
        num_images=data.get("num_images"),
        status=TrainingStatus.PENDING,
    )
    db.session.add(new_project)
    db.session.commit()
    return jsonify(new_project.to_dict()), 201
@lora_bp.route("/projects/<int:project_id>", methods=["GET"])
def get_project(project_id):
    """Return one project serialized via ``to_dict()``; 404 if the id is unknown."""
    found = LoRAProject.query.get_or_404(project_id)
    serialized = found.to_dict()
    return jsonify(serialized)
@lora_bp.route("/projects/<int:project_id>/train", methods=["POST"])
def start_training(project_id):
    """Transition a project to ``RUNNING`` so that training can begin.

    Only projects whose status is ``PENDING`` or ``FAILED`` may be (re)started;
    any other status is answered with a 400. Unknown ids yield a 404 via
    ``get_or_404``.
    """
    project = LoRAProject.query.get_or_404(project_id)

    startable = (TrainingStatus.PENDING, TrainingStatus.FAILED)
    if project.status not in startable:
        return jsonify({"message": "Project is already running or completed"}), 400

    project.status = TrainingStatus.RUNNING
    project.started_at = datetime.utcnow()
    db.session.commit()
    # TODO: this is where the actual LoRA training service should be invoked;
    # for now only the status change is simulated.
    return jsonify({"message": f"Training started for project {project.name}"})
@lora_bp.route("/projects/<int:project_id>/status", methods=["GET"])
def get_training_status(project_id):
    """Report a project's status (enum ``.value``) and its ``progress`` field."""
    project = LoRAProject.query.get_or_404(project_id)
    payload = {
        "status": project.status.value,
        "progress": project.progress,
    }
    return jsonify(payload)
@lora_bp.route("/models", methods=["GET"])
def get_models():
    """Return a hard-coded catalogue of base models as a JSON array.

    No real model discovery happens here. ``vram_estimate`` matches the figure
    shown in each display name (apparently GB).
    """
    # Simulated list of available models: (id, display name, VRAM estimate).
    catalogue = (
        ("stable-diffusion-v1-5", "Stable Diffusion v1.5 (4GB)", 4),
        ("stable-diffusion-xl", "Stable Diffusion XL (7GB)", 7),
        ("dialo-gpt-medium", "DialoGPT Medium (1GB)", 1),
    )
    models = [
        {"id": model_id, "name": label, "vram_estimate": vram}
        for model_id, label, vram in catalogue
    ]
    return jsonify(models)