"""SQLAlchemy models describing LoRA training projects and their run state."""
from datetime import datetime, timezone
from enum import Enum
from typing import Optional

from sqlalchemy.orm import Mapped, mapped_column

from src.extensions import db


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Drop-in replacement for ``datetime.utcnow`` (deprecated since Python
    3.12). The tzinfo is stripped on purpose: the ``db.DateTime`` columns
    below are declared without ``timezone=True``, so stored values are naive
    and new values must stay comparable with them.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class TrainingStatus(Enum):
    """Lifecycle states of a LoRA training run."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class LoRAProject(db.Model):
    """A LoRA training project: its configuration, progress and outputs.

    Holds the base model, LoRA hyper-parameters, training/optimizer settings,
    dataset location, current training status/progress, output file paths,
    timestamps and the last error message.
    """

    __tablename__ = 'lora_projects'

    id: Mapped[int] = mapped_column(db.Integer, primary_key=True)
    name: Mapped[str] = mapped_column(db.String(100), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)

    # Model configuration
    base_model: Mapped[str] = mapped_column(db.String(200), nullable=False)  # e.g., "runwayml/stable-diffusion-v1-5"

    # LoRA parameters
    rank: Mapped[int] = mapped_column(db.Integer, default=4)
    alpha: Mapped[int] = mapped_column(db.Integer, default=32)
    dropout: Mapped[float] = mapped_column(db.Float, default=0.1)

    # Training parameters
    learning_rate: Mapped[float] = mapped_column(db.Float, default=1e-4)
    batch_size: Mapped[int] = mapped_column(db.Integer, default=1)
    num_epochs: Mapped[int] = mapped_column(db.Integer, default=10)

    # Optimization settings
    use_8bit_optimizer: Mapped[bool] = mapped_column(db.Boolean, default=True)
    use_gradient_checkpointing: Mapped[bool] = mapped_column(db.Boolean, default=True)
    mixed_precision: Mapped[str] = mapped_column(db.String(10), default="fp16")  # fp16, bf16, fp32

    # Dataset information
    dataset_path: Mapped[Optional[str]] = mapped_column(db.String(500), nullable=True)
    num_images: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)

    # Training status and results
    status: Mapped[TrainingStatus] = mapped_column(db.Enum(TrainingStatus), default=TrainingStatus.PENDING)
    progress: Mapped[float] = mapped_column(db.Float, default=0.0)  # 0.0 to 1.0
    current_epoch: Mapped[int] = mapped_column(db.Integer, default=0)
    current_loss: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True)

    # File paths
    output_path: Mapped[Optional[str]] = mapped_column(db.String(500), nullable=True)
    log_file: Mapped[Optional[str]] = mapped_column(db.String(500), nullable=True)

    # Timestamps (naive UTC; see _utcnow)
    created_at: Mapped[datetime] = mapped_column(db.DateTime, default=_utcnow)
    started_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
    completed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)

    # Error information
    error_message: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)

    def __repr__(self) -> str:
        """Return a short debug representation identifying the project."""
        return f"<LoRAProject id={self.id!r} name={self.name!r}>"

    @staticmethod
    def _isoformat(value: Optional[datetime]) -> Optional[str]:
        """Return ``value.isoformat()``, or ``None`` when *value* is unset."""
        return value.isoformat() if value is not None else None

    def to_dict(self) -> dict:
        """Serialize the project into a plain dict of JSON-friendly values.

        The status enum is emitted as its ``.value`` string, and datetimes are
        emitted as ISO-8601 strings. Unset values become ``None``. Column
        defaults are only applied on INSERT, so ``status`` and ``created_at``
        may still be ``None`` on an instance that has not been flushed yet.
        """
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'base_model': self.base_model,
            'rank': self.rank,
            'alpha': self.alpha,
            'dropout': self.dropout,
            'learning_rate': self.learning_rate,
            'batch_size': self.batch_size,
            'num_epochs': self.num_epochs,
            'use_8bit_optimizer': self.use_8bit_optimizer,
            'use_gradient_checkpointing': self.use_gradient_checkpointing,
            'mixed_precision': self.mixed_precision,
            'dataset_path': self.dataset_path,
            'num_images': self.num_images,
            'status': self.status.value if self.status is not None else None,
            'progress': self.progress,
            'current_epoch': self.current_epoch,
            'current_loss': self.current_loss,
            'output_path': self.output_path,
            'log_file': self.log_file,
            'created_at': self._isoformat(self.created_at),
            'started_at': self._isoformat(self.started_at),
            'completed_at': self._isoformat(self.completed_at),
            'error_message': self.error_message,
        }