"""Flask blueprint exposing REST endpoints for LoRA training projects."""

from datetime import datetime

from flask import Blueprint, jsonify, request

from src.extensions import db
from src.models.lora_project import LoRAProject, TrainingStatus

lora_bp = Blueprint("lora_projects", __name__)

# Keys a client must supply (non-null) when creating a project; every other
# field has a default in create_project().
_REQUIRED_PROJECT_FIELDS = ("name", "base_model")

# Statuses from which training may be (re)started.
_STARTABLE_STATUSES = (TrainingStatus.PENDING, TrainingStatus.FAILED)

# Simulated catalogue of available base models. "vram_estimate" matches the
# GB figure shown in each display name.
_AVAILABLE_MODELS = [
    {"id": "stable-diffusion-v1-5", "name": "Stable Diffusion v1.5 (4GB)", "vram_estimate": 4},
    {"id": "stable-diffusion-xl", "name": "Stable Diffusion XL (7GB)", "vram_estimate": 7},
    {"id": "dialo-gpt-medium", "name": "DialoGPT Medium (1GB)", "vram_estimate": 1},
]


@lora_bp.route("/projects", methods=["GET"])
def get_projects():
    """Return all projects, newest first, serialized via ``to_dict()``."""
    projects = LoRAProject.query.order_by(LoRAProject.created_at.desc()).all()
    return jsonify([project.to_dict() for project in projects])


@lora_bp.route("/projects", methods=["POST"])
def create_project():
    """Create a new LoRA project from the JSON request body.

    Required keys: ``name`` and ``base_model``. Optional keys fall back to the
    defaults passed below. The new project is stored with status
    ``TrainingStatus.PENDING``.

    Returns:
        ``(json, 201)`` with the serialized project on success, or
        ``(json, 400)`` with a ``message`` when the body is not a JSON object
        or a required field is missing/null.
    """
    data = request.get_json()
    # get_json() may yield None (older Flask, non-JSON mimetype) or a JSON
    # array/scalar; either would crash on data[...] / data.get(...) below.
    if not isinstance(data, dict):
        return jsonify({"message": "Request body must be a JSON object"}), 400

    missing = [field for field in _REQUIRED_PROJECT_FIELDS if data.get(field) is None]
    if missing:
        return jsonify({"message": f"Missing required field(s): {', '.join(missing)}"}), 400

    new_project = LoRAProject(
        name=data["name"],
        description=data.get("description"),
        base_model=data["base_model"],
        rank=data.get("rank", 4),
        alpha=data.get("alpha", 32),
        dropout=data.get("dropout", 0.1),
        learning_rate=data.get("learning_rate", 1e-4),
        batch_size=data.get("batch_size", 1),
        num_epochs=data.get("num_epochs", 10),
        use_8bit_optimizer=data.get("use_8bit_optimizer", True),
        use_gradient_checkpointing=data.get("use_gradient_checkpointing", True),
        mixed_precision=data.get("mixed_precision", "fp16"),
        dataset_path=data.get("dataset_path"),
        num_images=data.get("num_images"),
        status=TrainingStatus.PENDING,
    )
    db.session.add(new_project)
    db.session.commit()
    return jsonify(new_project.to_dict()), 201


# NOTE(review): the <int:...> converter assumes LoRAProject has an integer
# primary key — confirm against the model definition.
@lora_bp.route("/projects/<int:project_id>", methods=["GET"])
def get_project(project_id):
    """Return one project by id, or 404 if it does not exist."""
    project = LoRAProject.query.get_or_404(project_id)
    return jsonify(project.to_dict())


@lora_bp.route("/projects/<int:project_id>/train", methods=["POST"])
def start_training(project_id):
    """Mark a PENDING or FAILED project as RUNNING and record its start time.

    Returns 404 if the project does not exist and 400 if its current status
    is not one that allows starting.
    """
    project = LoRAProject.query.get_or_404(project_id)
    if project.status not in _STARTABLE_STATUSES:
        return jsonify({"message": "Project is already running or completed"}), 400

    project.status = TrainingStatus.RUNNING
    project.started_at = datetime.utcnow()
    db.session.commit()
    # TODO: invoke the actual LoRA training service here; for now only the
    # status change is simulated.
    return jsonify({"message": f"Training started for project {project.name}"})


@lora_bp.route("/projects/<int:project_id>/status", methods=["GET"])
def get_training_status(project_id):
    """Return the project's status (enum ``.value``) and ``progress``; 404 if missing."""
    project = LoRAProject.query.get_or_404(project_id)
    return jsonify({"status": project.status.value, "progress": project.progress})


@lora_bp.route("/models", methods=["GET"])
def get_models():
    """Return the (simulated) list of base models available for training."""
    return jsonify(_AVAILABLE_MODELS)