Spaces:
No application file
No application file
| from flask import Blueprint, request, jsonify | |
| from src.extensions import db | |
| from src.models.lora_project import LoRAProject, TrainingStatus | |
| from datetime import datetime | |
# Flask blueprint named "lora_projects" that groups the LoRA project endpoints.
# NOTE(review): none of the view functions in this module carry a visible
# ``@lora_bp.route(...)`` decorator — confirm how they are actually registered.
lora_bp = Blueprint(name="lora_projects", import_name=__name__)
def get_projects():
    """Return every LoRA project as a JSON list, newest ``created_at`` first."""
    newest_first = LoRAProject.created_at.desc()
    rows = LoRAProject.query.order_by(newest_first).all()
    serialized = [row.to_dict() for row in rows]
    return jsonify(serialized)
def create_project():
    """Create a new LoRA project from the JSON request body.

    Required JSON fields: ``name`` and ``base_model``. Every other field is
    optional and falls back to the default shown below. The new project always
    starts in ``TrainingStatus.PENDING``.

    Returns:
        ``(json, 201)`` with the serialized project on success, or
        ``(json, 400)`` with a ``message`` describing the problem when the body
        is not a JSON object or a required field is missing/empty.
    """
    # silent=True yields None (instead of raising) for a missing or malformed
    # JSON body, so the client gets a JSON 400 rather than a generic error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"message": "Request body must be a JSON object"}), 400

    # Previously data["name"] / data["base_model"] raised KeyError, which
    # Flask turned into a 500; a missing field is a client error.
    missing = [field for field in ("name", "base_model") if not data.get(field)]
    if missing:
        return jsonify(
            {"message": f"Missing required field(s): {', '.join(missing)}"}
        ), 400

    new_project = LoRAProject(
        name=data["name"],
        description=data.get("description"),
        base_model=data["base_model"],
        # LoRA hyper-parameters and training options, with defaults.
        rank=data.get("rank", 4),
        alpha=data.get("alpha", 32),
        dropout=data.get("dropout", 0.1),
        learning_rate=data.get("learning_rate", 1e-4),
        batch_size=data.get("batch_size", 1),
        num_epochs=data.get("num_epochs", 10),
        use_8bit_optimizer=data.get("use_8bit_optimizer", True),
        use_gradient_checkpointing=data.get("use_gradient_checkpointing", True),
        mixed_precision=data.get("mixed_precision", "fp16"),
        dataset_path=data.get("dataset_path"),
        num_images=data.get("num_images"),
        status=TrainingStatus.PENDING,
    )
    db.session.add(new_project)
    db.session.commit()
    return jsonify(new_project.to_dict()), 201
def get_project(project_id):
    """Return one LoRA project as JSON; ``get_or_404`` aborts for unknown ids."""
    return jsonify(LoRAProject.query.get_or_404(project_id).to_dict())
def start_training(project_id):
    """Mark a project as RUNNING and record its start time.

    Only projects whose status is PENDING or FAILED may be (re)started; any
    other status is rejected with HTTP 400.
    """
    project = LoRAProject.query.get_or_404(project_id)

    startable = (TrainingStatus.PENDING, TrainingStatus.FAILED)
    if project.status not in startable:
        return jsonify({"message": "Project is already running or completed"}), 400

    project.status = TrainingStatus.RUNNING
    project.started_at = datetime.utcnow()
    db.session.commit()

    # This is where the LoRA training service would be invoked.
    # For now, only the status change is simulated.
    return jsonify({"message": f"Training started for project {project.name}"})
def get_training_status(project_id):
    """Report a project's training status and progress as JSON.

    ``status`` is the ``.value`` of the project's status enum; ``progress`` is
    passed through from the model unchanged.
    """
    project = LoRAProject.query.get_or_404(project_id)
    payload = {
        "status": project.status.value,
        "progress": project.progress,
    }
    return jsonify(payload)
def get_models():
    """Return the catalogue of selectable base models as JSON.

    Each entry has an ``id``, a display ``name`` and a ``vram_estimate``
    (matches the figure in the display name; presumably GB — confirm).
    """
    # Simulated list of available models; not backed by real discovery yet.
    catalogue = (
        ("stable-diffusion-v1-5", "Stable Diffusion v1.5 (4GB)", 4),
        ("stable-diffusion-xl", "Stable Diffusion XL (7GB)", 7),
        ("dialo-gpt-medium", "DialoGPT Medium (1GB)", 1),
    )
    models = [
        {"id": model_id, "name": label, "vram_estimate": vram}
        for model_id, label, vram in catalogue
    ]
    return jsonify(models)