PR_IRminiSaaS / db.py
Corin1998's picture
Create db.py
7d17d15 verified
import os
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Enum, ForeignKey, JSON, Boolean
from sqlalchemy.orm import sessionmaker, declarative_base, relationship
from sqlalchemy.sql import func
import enum
def _sqlite_url(candidates=("/data/app.db", "/tmp/app.db", "./app.db")) -> str:
    """Return a SQLite URL for the first candidate location that is writable.

    On Hugging Face Spaces ``/data`` is the persistent volume, so it is tried
    first, then ``/tmp``, then the current working directory.

    For each candidate the parent directory is created if missing and then
    probed by creating, writing and deleting a uniquely named temporary file.
    A unique name (instead of a fixed ``<path>.touch``) keeps several
    processes that start at the same time from deleting each other's probe,
    which would make a writable directory look unwritable.

    Args:
        candidates: database file paths to try, in order of preference.

    Returns:
        ``"sqlite+pysqlite:///<path>"`` for the first writable candidate, or
        ``"sqlite+pysqlite:///./app.db"`` when none of them is writable.
    """
    import tempfile

    for path in candidates:
        directory = os.path.dirname(path) or "."
        try:
            os.makedirs(directory, exist_ok=True)
            # NamedTemporaryFile removes the probe file when it is closed.
            with tempfile.NamedTemporaryFile(dir=directory) as probe:
                probe.write(b"ok")
        except OSError:
            # Read-only filesystem, missing permission, a regular file in the
            # way of the directory, ...: fall through to the next candidate.
            continue
        return f"sqlite+pysqlite:///{path}"
    return "sqlite+pysqlite:///./app.db"
# Use DATABASE_URL when it is set; otherwise fall back to a local SQLite file.
# ``or`` (instead of a getenv default argument) keeps _sqlite_url() -- which
# creates directories and probes them with writes -- from running at all when
# an external database is configured, and treats an empty value as unset.
DATABASE_URL = os.getenv("DATABASE_URL") or _sqlite_url()

# SQLAlchemy 1.4+ no longer accepts the legacy "postgres://" scheme that some
# hosting providers still hand out; rewrite it to the "postgresql://" name.
if DATABASE_URL.startswith("postgres://"):
    DATABASE_URL = "postgresql://" + DATABASE_URL[len("postgres://"):]

# Compare the dialect part of the scheme so that both "sqlite:///..." and
# driver-qualified "sqlite+pysqlite:///..." URLs are recognised.
IS_SQLITE = DATABASE_URL.partition(":")[0].partition("+")[0] == "sqlite"

engine = create_engine(
    DATABASE_URL,
    # Pre-ping detects stale pooled connections to a database server; a
    # local SQLite file does not need it.
    pool_pre_ping=not IS_SQLITE,
    future=True,
    # sqlite3 connections reject use from a thread other than the one that
    # created them; presumably sessions are used from worker threads of the
    # web framework -- confirm against the callers.
    connect_args={"check_same_thread": False} if IS_SQLITE else {},
)

# Session factory: transactions are committed explicitly and queries do not
# flush pending changes automatically.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, future=True)

# Declarative base that every ORM model in this module inherits from.
Base = declarative_base()
class Tone(str, enum.Enum):
    """Writing tone of a draft, stored in the ``Draft.tone`` column.

    Members also subclass ``str``, so each one compares equal to (and
    JSON-serializes as) its plain string value.
    """

    neutral = "neutral"
    formal = "formal"
    friendly = "friendly"
    investor = "investor"
    pr_bold = "pr_bold"
class ContentType(str, enum.Enum):
    """Kind of document a draft is, stored in ``Draft.content_type``.

    As a ``str`` subclass, every member equals its string value.
    """

    press_release = "press_release"
    ir_letter = "ir_letter"
    investor_summary = "investor_summary"
class DraftStatus(str, enum.Enum):
    """Workflow state of a draft, stored in ``Draft.status``.

    Members are declared in what looks like lifecycle order
    (draft -> pending -> approved -> scheduled -> sent); nothing here
    enforces that order. Each member equals its string value.
    """

    draft = "draft"
    pending = "pending"
    approved = "approved"
    scheduled = "scheduled"
    sent = "sent"
class Draft(Base):
    """ORM model for the ``drafts`` table.

    Column groups, in declaration order:
    - the input: ``source_type``, ``source_ref``, ``raw_text``;
    - the requested ``content_type`` and ``tone``;
    - the content: ``title``, ``body_md``;
    - the workflow ``status`` and the timestamps;
    - two alternative subject lines;
    - per-channel delivery settings.

    Related ``Delivery`` rows are reachable through ``deliveries``.
    """
    __tablename__ = "drafts"
    # NOTE(review): a primary key is already indexed, so index=True likely
    # just adds a redundant second index (ix_drafts_id).
    id = Column(Integer, primary_key=True, index=True)
    # String(20); presumably labels where raw_text came from -- confirm.
    source_type = Column(String(20))
    # Free-form reference to the source (URL, file name, ...?) -- confirm.
    source_ref = Column(Text)
    raw_text = Column(Text)
    # Required; stored through SQLAlchemy's Enum type.
    content_type = Column(Enum(ContentType), nullable=False)
    # Client-side default applied on INSERT when no tone is given.
    tone = Column(Enum(Tone), nullable=False, default=Tone.neutral)
    title = Column(String(300))
    # Body text; Markdown judging by the name -- not validated here.
    body_md = Column(Text)
    # New rows start as DraftStatus.draft unless a status is given.
    status = Column(Enum(DraftStatus), nullable=False, default=DraftStatus.draft)
    # Filled in by the database at INSERT time.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    # NOTE(review): only onupdate is set, so this stays NULL until the row is
    # first UPDATEd through SQLAlchemy; add a default if "last modified"
    # should also cover creation.
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())
    # Two alternative subject lines, presumably for A/B testing -- confirm.
    subject_a = Column(String(300))
    subject_b = Column(String(300))
    # Recipient list stored as plain text; the separator format is not
    # visible here -- confirm against the code that writes it.
    deliver_email_list = Column(Text)
    # Per-channel delivery flags; False on INSERT unless set.
    deliver_x = Column(Boolean, default=False)
    deliver_note = Column(Boolean, default=False)
    # One-to-many counterpart of Delivery.draft. No cascade is configured:
    # deleting a Draft does not delete its deliveries. SQLAlchemy's default
    # is to set their draft_id to NULL instead.
    deliveries = relationship("Delivery", back_populates="draft")
class Delivery(Base):
    """ORM model for the ``deliveries`` table.

    Each row is linked to a ``Draft`` through ``draft_id``, and records one
    channel (``channel``) together with JSON ``payload`` and ``result`` data.
    ``draft`` and ``Draft.deliveries`` are the two sides of that link.
    """
    __tablename__ = "deliveries"
    id = Column(Integer, primary_key=True)
    # index=True: SQLite and PostgreSQL do not index foreign-key columns on
    # their own, and loading Draft.deliveries filters on this column. Only
    # newly created tables get the index; create_all never alters existing
    # tables.
    draft_id = Column(Integer, ForeignKey("drafts.id"), index=True)
    channel = Column(String(10), nullable=False) # email/x/note
    # Free-form JSON; presumably what was sent to the channel -- confirm.
    payload = Column(JSON)
    # Free-form JSON; presumably the channel's response -- confirm.
    result = Column(JSON)
    # Filled in by the database at INSERT time.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    draft = relationship("Draft", back_populates="deliveries")
def init_db():
    """Create every table registered on ``Base.metadata`` that is missing.

    ``create_all`` checks for each table first, so existing tables and their
    data are left alone. This also means later column or index changes to
    the models are not applied to an existing database.
    """
    metadata = Base.metadata
    metadata.create_all(engine, checkfirst=True)