Spaces:
No application file
No application file
from datetime import datetime, timezone
from enum import Enum
from typing import Optional

from sqlalchemy.orm import Mapped, mapped_column

from src.extensions import db
class TrainingStatus(Enum):
    """Lifecycle states of a LoRA training run.

    Each member's value is the lowercase string form of its name.
    ``LoRAProject.to_dict`` emits that value for the ``status`` field.
    """

    # Created and not yet started.
    PENDING = "pending"
    # Training in progress.
    RUNNING = "running"
    # Terminal states.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class LoRAProject(db.Model):
    """A LoRA fine-tuning project: its configuration, progress and results.

    Rows are stored in the ``lora_projects`` table.

    Columns declared with ``nullable=True`` are annotated ``Optional[...]`` so
    the type hints match the schema. The explicit ``nullable`` arguments keep
    the generated DDL unchanged. Columns without an explicit ``nullable`` take
    NOT NULL from their non-Optional annotation.

    Python-side ``default=`` values are applied by SQLAlchemy at INSERT time.
    Before the first flush those attributes read as ``None``.
    """

    __tablename__ = 'lora_projects'

    id: Mapped[int] = mapped_column(db.Integer, primary_key=True)
    name: Mapped[str] = mapped_column(db.String(100), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)

    # Model configuration
    # Base model identifier, e.g. "runwayml/stable-diffusion-v1-5".
    base_model: Mapped[str] = mapped_column(db.String(200), nullable=False)

    # LoRA parameters
    rank: Mapped[int] = mapped_column(db.Integer, default=4)
    alpha: Mapped[int] = mapped_column(db.Integer, default=32)
    dropout: Mapped[float] = mapped_column(db.Float, default=0.1)

    # Training parameters
    learning_rate: Mapped[float] = mapped_column(db.Float, default=1e-4)
    batch_size: Mapped[int] = mapped_column(db.Integer, default=1)
    num_epochs: Mapped[int] = mapped_column(db.Integer, default=10)

    # Optimization settings
    use_8bit_optimizer: Mapped[bool] = mapped_column(db.Boolean, default=True)
    use_gradient_checkpointing: Mapped[bool] = mapped_column(db.Boolean, default=True)
    # Expected values: "fp16", "bf16" or "fp32". Not validated by this model.
    mixed_precision: Mapped[str] = mapped_column(db.String(10), default="fp16")

    # Dataset information
    dataset_path: Mapped[Optional[str]] = mapped_column(db.String(500), nullable=True)
    num_images: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)

    # Training status and results
    status: Mapped[TrainingStatus] = mapped_column(
        db.Enum(TrainingStatus), default=TrainingStatus.PENDING
    )
    # Intended range is 0.0 to 1.0. The schema does not enforce it.
    progress: Mapped[float] = mapped_column(db.Float, default=0.0)
    current_epoch: Mapped[int] = mapped_column(db.Integer, default=0)
    current_loss: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True)

    # File paths
    output_path: Mapped[Optional[str]] = mapped_column(db.String(500), nullable=True)
    log_file: Mapped[Optional[str]] = mapped_column(db.String(500), nullable=True)

    # Timestamps
    created_at: Mapped[datetime] = mapped_column(
        db.DateTime,
        # Naive UTC, the same value datetime.utcnow() produced. utcnow() is
        # deprecated since Python 3.12, and the column is not timezone-aware,
        # so tzinfo is stripped to keep stored values unchanged.
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )
    # Set by callers outside this model. Presumably also naive UTC, to match
    # created_at; confirm at the call sites.
    started_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
    completed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)

    # Error information
    error_message: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)

    def __repr__(self):
        """Return a short debug representation containing the project name."""
        return f'<LoRAProject {self.name}>'

    def to_dict(self):
        """Serialize the project to a plain dict of JSON-friendly values.

        Returns:
            A dict with one entry per column. ``status`` is the enum's string
            value. Datetimes are ISO-8601 strings without a UTC offset, because
            the stored values are naive. Unset (``None``) values stay ``None``.
        """
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'base_model': self.base_model,
            'rank': self.rank,
            'alpha': self.alpha,
            'dropout': self.dropout,
            'learning_rate': self.learning_rate,
            'batch_size': self.batch_size,
            'num_epochs': self.num_epochs,
            'use_8bit_optimizer': self.use_8bit_optimizer,
            'use_gradient_checkpointing': self.use_gradient_checkpointing,
            'mixed_precision': self.mixed_precision,
            'dataset_path': self.dataset_path,
            'num_images': self.num_images,
            # status may still be None on an object that has not been flushed yet.
            'status': self.status.value if self.status else None,
            'progress': self.progress,
            'current_epoch': self.current_epoch,
            'current_loss': self.current_loss,
            'output_path': self.output_path,
            'log_file': self.log_file,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'started_at': self.started_at.isoformat() if self.started_at else None,
            'completed_at': self.completed_at.isoformat() if self.completed_at else None,
            'error_message': self.error_message,
        }