# Pose-Detection-App / database.py
# Last change by vertalius: "Update database.py" (commit d9c1c4a, verified)
import os
from datetime import datetime, timezone

from sqlalchemy import create_engine, Column, Integer, String, JSON, ForeignKey, DateTime, Boolean
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship
# Create database engine.
# The connection URL is read from the DATABASE_URL environment variable and
# falls back to a SQLite file "database.db" in the working directory.
DATABASE_URL = os.environ.get('DATABASE_URL', 'sqlite:///./database.db')  # Default to SQLite

# SQLAlchemy 1.4+ no longer recognises the legacy "postgres://" scheme that
# some hosting providers still put in DATABASE_URL; rewrite it to the
# canonical "postgresql://" so such URLs keep working.
if DATABASE_URL.startswith("postgres://"):
    DATABASE_URL = "postgresql://" + DATABASE_URL[len("postgres://"):]

# The sqlite3 driver refuses to use a connection from a thread other than the
# one that created it; that check is disabled for SQLite URLs only.
engine = create_engine(
    DATABASE_URL,
    connect_args={"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {},
)

# Session factory bound to the engine, with autocommit and autoflush disabled.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# sqlalchemy.orm.declarative_base exists since SQLAlchemy 1.4; the
# sqlalchemy.ext.declarative version emits MovedIn20Warning on 2.x.
try:
    from sqlalchemy.orm import declarative_base
except ImportError:  # SQLAlchemy < 1.4 only ships the ext.declarative variant
    from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()
class ProcessedFile(Base):
    """An uploaded file plus its processing status.

    Rows in ``pose_data`` and ``animation_data`` reference this row through
    their ``file_id`` foreign key.
    """

    __tablename__ = "processed_files"

    id = Column(Integer, primary_key=True, index=True)
    filename = Column(String, nullable=False)
    file_type = Column(String, nullable=False)  # 'image' or 'video'
    processing_status = Column(String, nullable=False, default="pending")

    # Naive UTC timestamps -- the same values datetime.utcnow() produced, but
    # without the DeprecationWarning utcnow() emits on Python 3.12+.
    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )

    # The children's file_id is NOT NULL, so without a delete cascade
    # SQLAlchemy would try to set it to NULL when a file is deleted and the
    # flush would fail. Cascade deletes (and orphan removal) to the children.
    pose_data = relationship(
        "PoseData", back_populates="file", cascade="all, delete-orphan"
    )
    animation_data = relationship(
        "AnimationData", back_populates="file", cascade="all, delete-orphan"
    )
class PoseData(Base):
    """Pose landmarks detected for one frame of a ProcessedFile."""

    __tablename__ = "pose_data"

    id = Column(Integer, primary_key=True, index=True)
    # Indexed: loading ProcessedFile.pose_data filters on this column.
    # (create_all does not add the index to an already-existing table.)
    file_id = Column(
        Integer, ForeignKey("processed_files.id"), nullable=False, index=True
    )
    frame_number = Column(Integer, default=0)  # 0 for images, frame number for videos
    landmarks = Column(JSON, nullable=False)
    # NOTE(review): presumably a user-edited copy of ``landmarks``, with
    # ``is_corrected`` marking rows that have one -- confirm against callers.
    corrected_landmarks = Column(JSON)
    is_corrected = Column(Boolean, default=False)
    # Naive UTC timestamp, equivalent to the deprecated datetime.utcnow().
    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )

    file = relationship("ProcessedFile", back_populates="pose_data")
class AnimationData(Base):
    """Skeleton/animation data generated for a ProcessedFile."""

    __tablename__ = "animation_data"

    id = Column(Integer, primary_key=True, index=True)
    # Indexed: loading ProcessedFile.animation_data filters on this column.
    # (create_all does not add the index to an already-existing table.)
    file_id = Column(
        Integer, ForeignKey("processed_files.id"), nullable=False, index=True
    )
    skeleton_data = Column(JSON, nullable=False)
    # Target export format; "unreal" when not specified.
    export_format = Column(String, nullable=False, default="unreal")
    # Naive UTC timestamp, equivalent to the deprecated datetime.utcnow().
    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )

    file = relationship("ProcessedFile", back_populates="animation_data")
# Create, at import time, every table declared on Base that does not exist
# yet. checkfirst=True (SQLAlchemy's default) skips tables already present;
# no schema migration is performed.
Base.metadata.create_all(engine, checkfirst=True)
def get_db():
    """Yield a fresh session from SessionLocal and close it afterwards.

    Written as a generator so the caller receives the session at the
    ``yield``. The session is closed whether the consumer finishes normally,
    raises, or closes the generator early.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Always release the session, even on error.
        session.close()