# Source: scriptwriter/db.py, uploaded by kreemyyyy ("Upload 13 files", commit fd88516).
# db.py
import os, json, random
from contextlib import contextmanager
from typing import List, Iterable, Tuple, Optional
from sqlmodel import SQLModel, create_engine, Session, select
from datetime import datetime
# ---- Configure DB ----
# Connection string; override it with the DB_URL environment variable.
# Without it, a SQLite file named studio.db (relative to the working
# directory) is used.
DB_URL = os.getenv("DB_URL", "sqlite:///studio.db")
# Single module-wide engine that every Session in this module is bound to;
# SQL statement echoing is disabled.
engine = create_engine(DB_URL, echo=False)
# ---- Models ----
from models import Script, Rating # make sure Script has: is_reference: bool, plus the other fields
# ---- Init / Session ----
def init_db() -> None:
    """Create all tables registered on ``SQLModel.metadata`` using the module engine."""
    SQLModel.metadata.create_all(bind=engine)
@contextmanager
def get_session():
    """
    Yield a ``Session`` bound to the module engine, closing it on exit.

    Nothing is committed automatically; callers call ``commit()`` themselves.
    """
    session = Session(engine)
    try:
        yield session
    finally:
        session.close()
# ---- Helpers for import ----
def _payload_from_jsonl_row(row: dict) -> Tuple[dict, str, str]:
    """
    Map one JSONL row onto keyword arguments for ``Script``.

    Mapping (JSONL key -> Script column):
      id / theme      -> title (id preferred, then theme, then "Imported Script")
      model_name      -> creator (default "Unknown")
      video_type      -> content_type, lower-cased (default "talking_style")
      tonality        -> tone, joined with ", " (default "playful")
      video_hook      -> hook
      storyboard      -> beats
      caption_options -> caption, joined with " | " and cut to 180 chars
      hashtags        -> hashtags

    ``tonality`` and ``caption_options`` may be a list, a single string or
    null. They are normalised to a list first, so a bare string is never
    joined character by character.

    TODO: consider adding an ``external_id`` column to ``Script`` and
    de-duplicating on it instead of on the title.

    Returns:
        ``(payload, dedupe_key_title, dedupe_key_creator)``: the payload dict
        plus the title and creator values used to look up an existing row.
    """

    def _as_str_list(value) -> List[str]:
        # Normalise None / a single string / an iterable into a list of str.
        if value is None:
            return []
        if isinstance(value, str):
            return [value] if value else []
        return [str(v) for v in value if v is not None]

    raw_id = row.get("id")
    # Coerce numeric ids to str so the title (and dedupe key) is always text.
    external_id = str(raw_id) if raw_id else ""

    tone = ", ".join(_as_str_list(row.get("tonality"))) or "playful"
    # Compact caption: all caption options on one line, truncated.
    caption = " | ".join(_as_str_list(row.get("caption_options")))[:180]

    payload = dict(
        # core identity
        creator=row.get("model_name") or "Unknown",
        content_type=(row.get("video_type", "") or "talking_style").lower(),
        tone=tone,
        title=external_id or row.get("theme", "") or "Imported Script",
        hook=row.get("video_hook") or "",
        # structured fields
        beats=row.get("storyboard", []) or [],
        voiceover="",
        caption=caption,
        hashtags=row.get("hashtags", []) or [],
        cta="",
        # flags
        source="import",
        is_reference=True,  # imported examples are used as references
        compliance="pass",  # provisional; import_jsonl re-scores after building the row
    )
    return payload, payload["title"], payload["creator"]
def _score_and_update_compliance(s: Script) -> None:
    """
    Best-effort re-scoring of ``s.compliance`` using the optional ``compliance`` module.

    Calls ``compliance.score_script(compliance.blob_from(<field dict of s>))``
    and stores the first element of the returned pair on ``s.compliance``.

    - If the module (or either function) cannot be imported, ``s`` is left
      unchanged silently: the scorer is optional.
    - Any other error while scoring is logged as a warning (with traceback)
      and ``s`` is left unchanged, so a scorer bug never aborts an import.
    """
    try:
        from compliance import blob_from, score_script
    except ImportError:
        # Optional dependency not installed: keep the existing compliance value.
        return
    try:
        # Prefer the pydantic v2 API; fall back to the deprecated .dict() on
        # older SQLModel/pydantic versions.
        dump = getattr(s, "model_dump", None) or s.dict
        lvl, _ = score_script(blob_from(dump()))
        s.compliance = lvl
    except Exception:
        import logging
        logging.getLogger(__name__).warning(
            "Compliance scoring failed; keeping compliance=%r",
            getattr(s, "compliance", None),
            exc_info=True,
        )
def _iter_jsonl(path: str) -> Iterable[dict]:
    """
    Yield one parsed JSON value per non-blank line of the JSONL file at *path*.

    The file is decoded as UTF-8. A leading byte-order mark, as written by
    some Windows editors, is tolerated. Whitespace-only lines are skipped.

    Raises:
        json.JSONDecodeError: if a line is not valid JSON. The message is
            prefixed with ``<path>:<line number>:`` so the bad record can be
            located.
    """
    # "utf-8-sig" strips an optional BOM; with plain "utf-8" the BOM would
    # make json.loads reject the first line.
    with open(path, "r", encoding="utf-8-sig") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as err:
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            yield obj
# ---- Public: Importer ----
def import_jsonl(path: str) -> int:
    """
    Upsert every record of the JSONL file at *path* into the ``Script`` table.

    Each record is converted with ``_payload_from_jsonl_row``. If a script with
    the same (title, creator) already exists, all payload fields are
    overwritten and ``updated_at`` is refreshed. Otherwise a new ``Script`` is
    inserted. Compliance is re-scored either way, and everything is committed
    once at the end.

    Returns:
        The number of records inserted or updated.
    """
    init_db()
    upserted = 0
    with get_session() as ses:
        for record in _iter_jsonl(path):
            payload, title_key, creator_key = _payload_from_jsonl_row(record)
            query = select(Script).where(
                Script.title == title_key,
                Script.creator == creator_key,
            )
            match = ses.exec(query).first()
            if match is None:
                fresh = Script(**payload)
                _score_and_update_compliance(fresh)
                ses.add(fresh)
            else:
                # Overwrite every imported field on the existing row.
                for field, value in payload.items():
                    setattr(match, field, value)
                _score_and_update_compliance(match)
                match.updated_at = datetime.utcnow()
                ses.add(match)
            upserted += 1
        ses.commit()
    return upserted
# ---- Ratings API ----
def add_rating(script_id: int,
               overall: float,
               hook: Optional[float] = None,
               originality: Optional[float] = None,
               style_fit: Optional[float] = None,
               safety: Optional[float] = None,
               notes: Optional[str] = None,
               rater: str = "human") -> None:
    """
    Store one rating for a script and refresh the script's cached aggregates.

    The new ``Rating`` row and the recomputed aggregate columns on the
    ``Script`` are committed in one transaction. If the recompute fails,
    the rating is rolled back as well.

    Args:
        script_id: primary key of the rated ``Script``.
        overall: overall score.
        hook, originality, style_fit, safety: optional per-dimension scores;
            ``None`` means "not rated".
        notes: optional free-text comment.
        rater: label for who gave the rating (default "human").

    Raises:
        ValueError: if no ``Script`` with ``script_id`` exists.
    """
    with get_session() as ses:
        if ses.get(Script, script_id) is None:
            raise ValueError(f"No Script with id {script_id!r}")
        ses.add(Rating(
            script_id=script_id, overall=overall, hook=hook,
            originality=originality, style_fit=style_fit, safety=safety,
            notes=notes, rater=rater,
        ))
        # flush (not commit): the aggregate query must see the new rating,
        # but both writes should be committed or rolled back together.
        ses.flush()
        _recompute_script_aggregates(ses, script_id)
        ses.commit()
def _recompute_script_aggregates(ses: Session, script_id: int) -> None:
    """
    Refresh the cached rating aggregates on one ``Script`` within ``ses``.

    For each dimension (overall, hook, originality, style_fit, safety):
    - average the non-``None`` values over all ``Rating`` rows for ``script_id``;
    - round the mean to 3 decimals, or use ``None`` if no rating set that dimension;
    - write the result to the matching ``score_<dimension>`` column.

    ``ratings_count`` and ``updated_at`` are refreshed too. The script is added
    to ``ses`` but not committed; the caller owns the transaction.

    Returns without changes if there are no ratings for ``script_id`` or no
    ``Script`` with that id exists.
    """
    rows = list(ses.exec(select(Rating).where(Rating.script_id == script_id)))
    if not rows:
        return
    script = ses.get(Script, script_id)
    if script is None:
        # Ratings point at a missing script: nothing to cache the values on.
        return

    def avg(field: str) -> Optional[float]:
        vals = [getattr(r, field) for r in rows if getattr(r, field) is not None]
        return round(sum(vals) / len(vals), 3) if vals else None

    for dim in ("overall", "hook", "originality", "style_fit", "safety"):
        setattr(script, f"score_{dim}", avg(dim))
    script.ratings_count = len(rows)
    script.updated_at = datetime.utcnow()
    ses.add(script)
# ---- Public: Reference retrieval for generation ----
def extract_snippets_from_script(s: Script, max_lines: int = 3) -> List[str]:
    """
    Collect short prompt snippets from a script, in priority order.

    Candidates, in order:
    - the stripped hook;
    - the first two stripped beats;
    - the stripped caption cut to 120 characters.

    Empty candidates and exact duplicates are dropped (the first occurrence
    wins), and at most ``max_lines`` snippets are returned.
    """
    candidates: List[str] = [s.hook.strip()] if s.hook else []
    if s.beats:
        candidates += [beat.strip() for beat in s.beats[:2]]
    if s.caption:
        candidates.append(s.caption.strip()[:120])
    # dict keeps insertion order, so this is an order-preserving dedupe.
    unique = list(dict.fromkeys(c for c in candidates if c))
    return unique[:max_lines]
def get_library_refs(creator: str, content_type: str, k: int = 6) -> List[str]:
    """
    Return prompt snippets from the newest reference scripts for a creator.

    Selects up to ``k`` scripts, newest ``created_at`` first, that:
    - match ``creator`` and ``content_type``;
    - are flagged ``is_reference``;
    - do not have ``compliance == "fail"``.

    Their snippets (from ``extract_snippets_from_script``) are flattened,
    de-duplicated (first occurrence wins) and capped at 8.

    Returns ``[]`` when ``k`` <= 0 or no script matches.
    """
    if k <= 0:
        return []
    with get_session() as ses:
        stmt = (
            select(Script)
            .where(
                Script.creator == creator,
                Script.content_type == content_type,
                Script.is_reference == True,  # noqa: E712 -- SQL expression, not a bool test
                # NOTE(review): in SQL, NULL != 'fail' is not true, so rows with
                # a NULL compliance are excluded as well -- confirm that is intended.
                Script.compliance != "fail",
            )
            .order_by(Script.created_at.desc())
            # Let the database apply the cap instead of loading every match.
            .limit(k)
        )
        snippets = [
            snippet
            for script in ses.exec(stmt)
            for snippet in extract_snippets_from_script(script)
        ]
    # Order-preserving dedupe, then cap to keep the prompt compact.
    return list(dict.fromkeys(snippets))[:8]
# ---- HYBRID reference retrieval ----
def get_hybrid_refs(creator: str, content_type: str, k: int = 6,
                    top_n: int = 3, explore_n: int = 2, newest_n: int = 1,
                    rng: Optional[random.Random] = None) -> List[str]:
    """
    Build a prompt-sized snippet list from a mix of reference scripts.

    Candidate scripts match ``creator`` and ``content_type``, are flagged
    ``is_reference``, and do not have ``compliance == "fail"``. From them:

    - ``top_n`` scripts with the highest ``score_overall`` (unscored counts as 0.0): exploit
    - ``explore_n`` random scripts not already picked by the other buckets: explore
    - ``newest_n`` most recently created scripts: freshness

    The buckets are merged in that order (best, explore, newest) without
    duplicate scripts, then processed as follows:
    - the merged list is cut to ``k`` scripts;
    - their snippets are flattened via ``extract_snippets_from_script``;
    - snippets are de-duplicated (first occurrence wins) and capped at 8.

    Args:
        k: maximum number of scripts that contribute snippets.
        rng: optional random source for the explore bucket, e.g. a seeded
            ``random.Random`` for reproducible picks; defaults to the
            module-level ``random`` functions.

    Returns:
        The snippet list, or ``[]`` when no candidate script exists.
    """
    sampler = random if rng is None else rng
    with get_session() as ses:
        all_refs = list(ses.exec(
            select(Script).where(
                Script.creator == creator,
                Script.content_type == content_type,
                Script.is_reference == True,  # noqa: E712 -- SQL expression, not a bool test
                Script.compliance != "fail",
            )
        ))
        if not all_refs:
            return []

        best = sorted(all_refs, key=lambda s: s.score_overall or 0.0, reverse=True)[:top_n]
        newest = sorted(all_refs, key=lambda s: s.created_at, reverse=True)[:newest_n]

        # Explore only among scripts not already chosen. Compare primary keys
        # through a set rather than scanning the bucket lists per candidate.
        taken_ids = {s.id for s in best} | {s.id for s in newest}
        remainder = [r for r in all_refs if r.id not in taken_ids]
        explore = sampler.sample(remainder, max(0, min(explore_n, len(remainder))))

        # Merge buckets in priority order, keeping the first copy of each script.
        chosen: List[Script] = []
        seen_ids = set()
        for script in (*best, *explore, *newest):
            if script.id not in seen_ids:
                seen_ids.add(script.id)
                chosen.append(script)

        snippets = [
            snippet
            for script in chosen[:k]
            for snippet in extract_snippets_from_script(script)
        ]
    # Order-preserving dedupe, then cap at ~8 lines to keep the prompt lean.
    return list(dict.fromkeys(snippets))[:8]