Spaces:
Sleeping
Sleeping
File size: 8,634 Bytes
fd88516 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 |
# db.py
import os, json, random
from contextlib import contextmanager
from typing import List, Iterable, Tuple, Optional
from sqlmodel import SQLModel, create_engine, Session, select
from datetime import datetime
# ---- Configure DB ----
DB_URL = os.environ.get("DB_URL", "sqlite:///studio.db")
engine = create_engine(DB_URL, echo=False)
# ---- Models ----
from models import Script, Rating # make sure Script has: is_reference: bool, plus the other fields
# ---- Init / Session ----
def init_db() -> None:
    """Create all tables registered on SQLModel metadata (idempotent; safe to call repeatedly)."""
    SQLModel.metadata.create_all(engine)
@contextmanager
def get_session():
    """Yield a short-lived SQLModel Session bound to the module-level engine.

    The session is closed automatically when the ``with`` block exits;
    callers are responsible for calling ``commit()`` themselves.
    """
    with Session(engine) as ses:
        yield ses
# ---- Helpers for import ----
def _payload_from_jsonl_row(row: dict) -> Tuple[dict, str, str]:
"""
Map a JSONL row (the file I generated for you) into Script columns.
Returns (payload, dedupe_key_title, dedupe_key_creator).
You can also add 'external_id' to Script model and dedupe on that.
"""
# Prefer using the JSON 'id' as an external identifier:
external_id = row.get("id", "")
# Tone could be an array; flatten for now
tone = ", ".join(row.get("tonality", [])) or "playful"
# Compact caption: use caption options line as a quick reference
caption = " | ".join(row.get("caption_options", []))[:180]
payload = dict(
# core identity
creator=row.get("model_name", "Unknown"),
content_type=(row.get("video_type", "") or "talking_style").lower(),
tone=tone,
title=external_id or row.get("theme", "") or "Imported Script",
hook=row.get("video_hook") or "",
# structured fields
beats=row.get("storyboard", []) or [],
voiceover="",
caption=caption,
hashtags=row.get("hashtags", []) or [],
cta="",
# flags
source="import",
is_reference=True, # mark imported examples as references
compliance="pass", # we'll score again after save if you want
)
return payload, payload["title"], payload["creator"]
def _score_and_update_compliance(s: Script) -> None:
    """Optional: score compliance using the local rule-checker.

    Best-effort by design: if the project-local ``compliance`` module is
    absent or scoring raises, ``s`` keeps its current ``compliance`` value.
    Mutates ``s`` in place; does not add or commit to any session.
    """
    try:
        # Deferred import: the compliance module is an optional dependency.
        from compliance import blob_from, score_script
        lvl, _ = score_script(blob_from(s.dict()))
        s.compliance = lvl
    except Exception:
        # If no compliance module or error, keep default
        pass
def _iter_jsonl(path: str) -> Iterable[dict]:
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
yield json.loads(line)
# ---- Public: Importer ----
def import_jsonl(path: str) -> int:
    """
    Import (upsert) scripts from a JSONL file produced earlier.

    Dedupe key is (creator, title).  Existing rows are overwritten field by
    field; new rows are inserted.  Returns the count of upserted rows
    (updates AND inserts both count).  A single commit happens at the end,
    so a mid-file failure rolls back the whole import.
    """
    init_db()
    count = 0
    with get_session() as ses:
        for row in _iter_jsonl(path):
            payload, key_title, key_creator = _payload_from_jsonl_row(row)
            # Look up a prior import of the same (creator, title) pair.
            existing = ses.exec(
                select(Script).where(
                    Script.title == key_title,
                    Script.creator == key_creator
                )
            ).first()
            if existing:
                # Update all fields from the fresh payload.
                for k, v in payload.items():
                    setattr(existing, k, v)
                _score_and_update_compliance(existing)
                # NOTE(review): naive UTC timestamp — matches the rest of
                # the file; only set on the update path, not on insert.
                existing.updated_at = datetime.utcnow()
                ses.add(existing)
            else:
                obj = Script(**payload)
                _score_and_update_compliance(obj)
                ses.add(obj)
            count += 1
        ses.commit()
    return count
# ---- Ratings API ----
def add_rating(script_id: int,
               overall: float,
               hook: Optional[float] = None,
               originality: Optional[float] = None,
               style_fit: Optional[float] = None,
               safety: Optional[float] = None,
               notes: Optional[str] = None,
               rater: str = "human") -> None:
    """
    Record a single rating event for a script, then refresh the script's
    cached aggregate scores.

    Only ``overall`` is required; the per-dimension scores and ``notes``
    are optional.  Commits twice: once to persist the Rating row (so the
    aggregate recompute can see it) and once for the updated Script.
    """
    with get_session() as ses:
        # store rating event
        ses.add(Rating(
            script_id=script_id, overall=overall, hook=hook,
            originality=originality, style_fit=style_fit, safety=safety,
            notes=notes, rater=rater
        ))
        ses.commit()
        # recompute cached aggregates on Script
        _recompute_script_aggregates(ses, script_id)
        ses.commit()
def _recompute_script_aggregates(ses: Session, script_id: int) -> None:
    """
    Refresh the cached ``score_*`` columns and ``ratings_count`` on a
    Script from all of its Rating rows.  The caller commits the session.

    Fix: ``ses.get`` can return ``None`` (e.g. the script was deleted
    between rating and recompute); previously that raised AttributeError.
    """
    rows = list(ses.exec(select(Rating).where(Rating.script_id == script_id)))
    if not rows:
        # No ratings yet: leave any existing cached aggregates untouched.
        return

    def avg(field: str):
        # Average over non-null values only; None when nobody scored this axis.
        vals = [getattr(r, field) for r in rows if getattr(r, field) is not None]
        return round(sum(vals) / len(vals), 3) if vals else None

    s: Optional[Script] = ses.get(Script, script_id)
    if s is None:
        # Script no longer exists; nothing to update.
        return
    s.score_overall = avg("overall")
    s.score_hook = avg("hook")
    s.score_originality = avg("originality")
    s.score_style_fit = avg("style_fit")
    s.score_safety = avg("safety")
    s.ratings_count = len(rows)
    # NOTE(review): naive UTC timestamp — consistent with the rest of the file.
    s.updated_at = datetime.utcnow()
    ses.add(s)
# ---- Public: Reference retrieval for generation ----
def extract_snippets_from_script(s: Script, max_lines: int = 3) -> List[str]:
    """Pull up to *max_lines* short reference lines from a script.

    Draws from the hook, the first two storyboard beats, and a truncated
    caption, skipping empties and order-preserving-deduping the result.
    """
    candidates: List[str] = []
    if s.hook:
        candidates.append(s.hook.strip())
    if s.beats:
        candidates += [beat.strip() for beat in s.beats[:2]]  # first 1–2 beats
    if s.caption:
        candidates.append(s.caption.strip()[:120])
    # Order-preserving dedupe, dropping empty strings.
    unique: List[str] = []
    for text in candidates:
        if text and text not in unique:
            unique.append(text)
    return unique[:max_lines]
def get_library_refs(creator: str, content_type: str, k: int = 6) -> List[str]:
    """
    Return up to 8 short reference snippets drawn from the ``k`` newest
    reference scripts matching (creator, content_type), skipping any
    script whose compliance is "fail".
    """
    with get_session() as ses:
        # NOTE(review): fetches every matching row, then slices to k in
        # Python — fine for a small library; revisit with .limit(k) if it grows.
        rows = list(ses.exec(
            select(Script)
            .where(
                Script.creator == creator,
                Script.content_type == content_type,
                Script.is_reference == True,  # '== True' is needed for SQL translation; don't "simplify"
                Script.compliance != "fail"
            )
            .order_by(Script.created_at.desc())
        ))[:k]
        snippets: List[str] = []
        for r in rows:
            snippets.extend(extract_snippets_from_script(r))
        # final dedupe, preserving first-seen order
        seen, uniq = set(), []
        for s in snippets:
            if s not in seen:
                uniq.append(s); seen.add(s)
        return uniq[:8]  # cap to keep the prompt lean
# ---- HYBRID reference retrieval ----
def get_hybrid_refs(creator: str, content_type: str, k: int = 6,
                    top_n: int = 3, explore_n: int = 2, newest_n: int = 1) -> List[str]:
    """
    Mix of:
      - top_n best scored references (exploit)
      - explore_n random references (explore)
      - newest_n most recent references (freshness)
    Returns flattened snippet list (cap ~8 to keep prompt lean).

    NOTE(review): the explore bucket uses ``random.sample`` without a
    seed, so results are non-deterministic between calls by design.
    """
    with get_session() as ses:
        all_refs = list(ses.exec(
            select(Script).where(
                Script.creator == creator,
                Script.content_type == content_type,
                Script.is_reference == True,  # '== True' is needed for SQL translation
                Script.compliance != "fail"
            )
        ))
        if not all_refs:
            return []
        # sort by score_overall (fallback to 0 for unrated) and pick top_n
        scored = sorted(all_refs, key=lambda s: (s.score_overall or 0.0), reverse=True)
        best = scored[:top_n]
        # newest by created_at
        newest = sorted(all_refs, key=lambda s: s.created_at, reverse=True)[:newest_n]
        # explore = random sample from the remainder
        # NOTE(review): 'not in' relies on model equality semantics here —
        # presumably value-based for SQLModel instances; confirm if rows
        # ever appear duplicated across buckets.
        remainder = [r for r in all_refs if r not in best and r not in newest]
        explore = random.sample(remainder, min(explore_n, len(remainder))) if remainder else []
        # merge (preserve order, dedupe by primary key)
        chosen_scripts = []
        seen_ids = set()
        for bucket in (best, explore, newest):
            for s in bucket:
                if s.id not in seen_ids:
                    chosen_scripts.append(s)
                    seen_ids.add(s.id)
        # cut to k scripts
        chosen_scripts = chosen_scripts[:k]
        # flatten snippets and cap to keep prompt compact
        snippets: List[str] = []
        for s in chosen_scripts:
            snippets.extend(extract_snippets_from_script(s))
        # dedupe again and cap ~8 lines
        seen, out = set(), []
        for sn in snippets:
            if sn not in seen:
                out.append(sn); seen.add(sn)
        return out[:8]
|