# Lora-trainer / lora_project.py
# (Allex21 — "Upload 24 files", commit 5bb2330, verified)
from datetime import datetime
from enum import Enum
from typing import Optional

from sqlalchemy.orm import Mapped, mapped_column

from src.extensions import db
class TrainingStatus(Enum):
    """Lifecycle states of a LoRA training run.

    Persisted on LoRAProject.status via a ``db.Enum`` column; the string
    values are what ``to_dict()`` exposes to API consumers.
    """
    PENDING = "pending"        # created but training has not started yet
    RUNNING = "running"        # training loop in progress
    COMPLETED = "completed"    # finished successfully
    FAILED = "failed"          # aborted with an error
    CANCELLED = "cancelled"    # stopped by user request
class LoRAProject(db.Model):
    """Database model for one LoRA fine-tuning project.

    Stores the base-model choice, LoRA hyperparameters, training/optimizer
    settings, dataset location, live status/progress fields updated during
    training, output artifact paths, timestamps, and any error message.

    Typing fix: columns declared ``nullable=True`` are annotated as
    ``Mapped[Optional[...]]`` — in SQLAlchemy 2.0 typed mappings a plain
    ``Mapped[T]`` implies NOT NULL, which contradicted the explicit
    ``nullable=True`` and misled static type checkers.
    """
    __tablename__ = 'lora_projects'

    id: Mapped[int] = mapped_column(db.Integer, primary_key=True)
    name: Mapped[str] = mapped_column(db.String(100), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)

    # Model configuration
    base_model: Mapped[str] = mapped_column(db.String(200), nullable=False)  # e.g., "runwayml/stable-diffusion-v1-5"

    # LoRA parameters
    rank: Mapped[int] = mapped_column(db.Integer, default=4)
    alpha: Mapped[int] = mapped_column(db.Integer, default=32)
    dropout: Mapped[float] = mapped_column(db.Float, default=0.1)

    # Training parameters
    learning_rate: Mapped[float] = mapped_column(db.Float, default=1e-4)
    batch_size: Mapped[int] = mapped_column(db.Integer, default=1)
    num_epochs: Mapped[int] = mapped_column(db.Integer, default=10)

    # Optimization settings
    use_8bit_optimizer: Mapped[bool] = mapped_column(db.Boolean, default=True)
    use_gradient_checkpointing: Mapped[bool] = mapped_column(db.Boolean, default=True)
    mixed_precision: Mapped[str] = mapped_column(db.String(10), default="fp16")  # fp16, bf16, fp32

    # Dataset information
    dataset_path: Mapped[Optional[str]] = mapped_column(db.String(500), nullable=True)
    num_images: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)

    # Training status and results (updated while a run is in progress)
    status: Mapped[TrainingStatus] = mapped_column(db.Enum(TrainingStatus), default=TrainingStatus.PENDING)
    progress: Mapped[float] = mapped_column(db.Float, default=0.0)  # 0.0 to 1.0
    current_epoch: Mapped[int] = mapped_column(db.Integer, default=0)
    current_loss: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True)

    # File paths for training artifacts
    output_path: Mapped[Optional[str]] = mapped_column(db.String(500), nullable=True)
    log_file: Mapped[Optional[str]] = mapped_column(db.String(500), nullable=True)

    # Timestamps
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12; consider
    # migrating to ``lambda: datetime.now(timezone.utc)`` — kept as-is here
    # because stored values would become timezone-aware (changes isoformat output).
    created_at: Mapped[datetime] = mapped_column(db.DateTime, default=datetime.utcnow)
    started_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
    completed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)

    # Error information (populated when status is FAILED)
    error_message: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)

    def __repr__(self) -> str:
        return f'<LoRAProject {self.name}>'

    def to_dict(self) -> dict:
        """Return a JSON-serializable dict of this project.

        Enum values are flattened to their string value; datetimes are
        ISO-8601 strings; unset optional fields serialize as None.
        """
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'base_model': self.base_model,
            'rank': self.rank,
            'alpha': self.alpha,
            'dropout': self.dropout,
            'learning_rate': self.learning_rate,
            'batch_size': self.batch_size,
            'num_epochs': self.num_epochs,
            'use_8bit_optimizer': self.use_8bit_optimizer,
            'use_gradient_checkpointing': self.use_gradient_checkpointing,
            'mixed_precision': self.mixed_precision,
            'status': self.status.value if self.status else None,
            'progress': self.progress,
            'current_epoch': self.current_epoch,
            'current_loss': self.current_loss,
            'dataset_path': self.dataset_path,
            'num_images': self.num_images,
            'output_path': self.output_path,
            'log_file': self.log_file,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'started_at': self.started_at.isoformat() if self.started_at else None,
            'completed_at': self.completed_at.isoformat() if self.completed_at else None,
            'error_message': self.error_message
        }