# File size: 2,889 bytes
# 5bb2330 (presumably the source revision/commit id — verify)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from flask import Blueprint, request, jsonify
from src.extensions import db
from src.models.lora_project import LoRAProject, TrainingStatus
from datetime import datetime

# Blueprint that groups the LoRA project and model routes registered below.
lora_bp = Blueprint(name="lora_projects", import_name=__name__)

@lora_bp.route("/projects", methods=["GET"])
def get_projects():
    """List all LoRA projects as JSON, ordered newest first by ``created_at``."""
    ordering = LoRAProject.created_at.desc()
    rows = LoRAProject.query.order_by(ordering).all()
    return jsonify([row.to_dict() for row in rows])

@lora_bp.route("/projects", methods=["POST"])
def create_project():
    """Create a new LoRA project from the JSON request body.

    Required JSON fields: ``name`` and ``base_model``. All other fields are
    optional and fall back to the defaults passed below. New projects always
    start in ``TrainingStatus.PENDING``.

    Returns:
        ``(project_json, 201)`` on success, or ``({"message": ...}, 400)``
        when the body is not a JSON object or a required field is missing.
    """
    # silent=True makes a missing/malformed/non-JSON body yield None instead
    # of raising, so we can answer with a JSON 400 rather than crashing with a
    # TypeError on ``data["name"]`` (previously an HTTP 500).
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"message": "Request body must be a JSON object"}), 400

    # Previously a missing key raised KeyError, which surfaced as a 500.
    missing = [
        field for field in ("name", "base_model") if data.get(field) in (None, "")
    ]
    if missing:
        return (
            jsonify({"message": f"Missing required field(s): {', '.join(missing)}"}),
            400,
        )

    new_project = LoRAProject(
        name=data["name"],
        description=data.get("description"),
        base_model=data["base_model"],
        rank=data.get("rank", 4),
        alpha=data.get("alpha", 32),
        dropout=data.get("dropout", 0.1),
        learning_rate=data.get("learning_rate", 1e-4),
        batch_size=data.get("batch_size", 1),
        num_epochs=data.get("num_epochs", 10),
        use_8bit_optimizer=data.get("use_8bit_optimizer", True),
        use_gradient_checkpointing=data.get("use_gradient_checkpointing", True),
        mixed_precision=data.get("mixed_precision", "fp16"),
        dataset_path=data.get("dataset_path"),
        num_images=data.get("num_images"),
        status=TrainingStatus.PENDING,
    )
    db.session.add(new_project)
    db.session.commit()
    return jsonify(new_project.to_dict()), 201

@lora_bp.route("/projects/<int:project_id>", methods=["GET"])
def get_project(project_id):
    """Return one project as JSON; ``get_or_404`` aborts with 404 if it is unknown."""
    return jsonify(LoRAProject.query.get_or_404(project_id).to_dict())

@lora_bp.route("/projects/<int:project_id>/train", methods=["POST"])
def start_training(project_id):
    """Transition a project to ``RUNNING`` and stamp ``started_at``.

    Only projects currently ``PENDING`` or ``FAILED`` may be started; any
    other status yields a 400 response.
    """
    project = LoRAProject.query.get_or_404(project_id)

    startable = (TrainingStatus.PENDING, TrainingStatus.FAILED)
    if project.status not in startable:
        return jsonify({"message": "Project is already running or completed"}), 400

    project.status = TrainingStatus.RUNNING
    project.started_at = datetime.utcnow()
    db.session.commit()
    # The real LoRA training service would be invoked here; for now only the
    # status transition is simulated.
    return jsonify({"message": f"Training started for project {project.name}"})

@lora_bp.route("/projects/<int:project_id>/status", methods=["GET"])
def get_training_status(project_id):
    """Report a project's training status (enum ``.value``) and its progress field."""
    project = LoRAProject.query.get_or_404(project_id)
    payload = {
        "status": project.status.value,
        "progress": project.progress,
    }
    return jsonify(payload)

@lora_bp.route("/models", methods=["GET"])
def get_models():
    """Return a hard-coded (simulated) catalogue of available base models.

    Each entry's ``vram_estimate`` mirrors the GB figure shown in its name.
    """
    catalogue = (
        ("stable-diffusion-v1-5", "Stable Diffusion v1.5 (4GB)", 4),
        ("stable-diffusion-xl", "Stable Diffusion XL (7GB)", 7),
        ("dialo-gpt-medium", "DialoGPT Medium (1GB)", 1),
    )
    models = [
        {"id": model_id, "name": label, "vram_estimate": vram}
        for model_id, label, vram in catalogue
    ]
    return jsonify(models)