Spaces:
Sleeping
Sleeping
| # db.py | |
| import os, json, random | |
| from contextlib import contextmanager | |
| from typing import List, Iterable, Tuple, Optional | |
| from sqlmodel import SQLModel, create_engine, Session, select | |
| from datetime import datetime | |
| # ---- Configure DB ---- | |
# Connection URL is read from the DB_URL environment variable; when it is
# unset, a SQLite file named studio.db (relative path) is used instead.
DB_URL = os.getenv("DB_URL", "sqlite:///studio.db")
# Single module-level engine shared by every helper below; SQL echo is off.
engine = create_engine(DB_URL, echo=False)
| # ---- Models ---- | |
| from models import Script, Rating # make sure Script has: is_reference: bool, plus the other fields | |
| # ---- Init / Session ---- | |
def init_db() -> None:
    """Create all tables registered on ``SQLModel.metadata`` via the module engine."""
    metadata = SQLModel.metadata
    metadata.create_all(engine)
@contextmanager
def get_session():
    """Yield a ``Session`` bound to the module-level engine, closing it on exit.

    The function is wrapped with ``contextlib.contextmanager`` so that it can
    be used as ``with get_session() as ses:``, which is how the helpers in
    this module call it.  Without the decorator the call returns a bare
    generator object, and a ``with`` statement on it fails because generators
    do not implement the context-manager protocol.

    Nothing is committed here; callers commit explicitly.
    """
    with Session(engine) as ses:
        yield ses
| # ---- Helpers for import ---- | |
def _payload_from_jsonl_row(row: dict) -> Tuple[dict, str, str]:
    """Map one parsed JSONL row onto keyword arguments for ``Script``.

    Args:
        row: A dict decoded from a single JSONL line.  The keys read are
            ``id``, ``model_name``, ``video_type``, ``tonality``, ``theme``,
            ``video_hook``, ``storyboard``, ``caption_options`` and
            ``hashtags``; all of them are optional.

    Returns:
        ``(payload, title, creator)`` where ``payload`` is the dict of Script
        column values and ``title`` / ``creator`` are the two values the
        importer dedupes on (they are the same objects stored in ``payload``).
    """

    def _as_list(value) -> list:
        # Exports are not always consistent: a list-valued field may arrive as
        # an explicit null or as a single bare string.  Joining a bare string
        # directly would split it into characters, and joining None raises.
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        return list(value)

    # The row's own id, when present, doubles as the script title.
    external_id = row.get("id", "")

    # Tone may be an array or a single string; flatten to one string.
    tone = ", ".join(_as_list(row.get("tonality"))) or "playful"

    # Compact caption: all caption options on one line, capped at 180 chars.
    caption = " | ".join(_as_list(row.get("caption_options")))[:180]

    payload = dict(
        # core identity
        # `or` (rather than a .get default) also covers null / empty names.
        creator=row.get("model_name") or "Unknown",
        content_type=(row.get("video_type", "") or "talking_style").lower(),
        tone=tone,
        title=external_id or row.get("theme", "") or "Imported Script",
        hook=row.get("video_hook") or "",
        # structured fields
        beats=row.get("storyboard", []) or [],
        voiceover="",
        caption=caption,
        hashtags=row.get("hashtags", []) or [],
        cta="",
        # flags
        source="import",
        is_reference=True,   # imported examples are marked as references
        compliance="pass",   # default level; may be re-scored after mapping
    )
    return payload, payload["title"], payload["creator"]
def _score_and_update_compliance(s: Script) -> None:
    """Best-effort: re-score ``s.compliance`` with the optional ``compliance`` module.

    ``s.dict()`` is passed through ``blob_from`` and ``score_script``; the
    first element of the result is stored on ``s.compliance``.

    This never raises.  If the ``compliance`` module cannot be imported the
    record is left untouched silently (the module is optional).  Any other
    failure also leaves the record untouched, but is logged as a warning so a
    broken scorer does not go unnoticed while records keep their old level.
    """
    try:
        from compliance import blob_from, score_script
    except ImportError:
        # Optional dependency is absent: keep whatever level is already set.
        return

    try:
        lvl, _ = score_script(blob_from(s.dict()))
        s.compliance = lvl
    except Exception:
        # Still best-effort, but make the failure visible instead of hiding it.
        import logging

        logging.getLogger(__name__).warning(
            "Compliance scoring failed; keeping compliance=%r",
            getattr(s, "compliance", None),
            exc_info=True,
        )
def _iter_jsonl(path: str) -> Iterable[dict]:
    """Lazily yield one decoded JSON value per non-blank line of the UTF-8 file at ``path``.

    Lines are stripped of surrounding whitespace first; lines that are empty
    after stripping are skipped.  A line that is not valid JSON propagates the
    ``json.loads`` error.
    """
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            text = raw.strip()
            if text:
                yield json.loads(text)
| # ---- Public: Importer ---- | |
def import_jsonl(path: str) -> int:
    """Upsert scripts from the JSONL file at ``path`` and return how many rows were processed.

    Tables are created first via ``init_db()``.  Each row is mapped with
    ``_payload_from_jsonl_row`` and matched against an existing ``Script``
    with the same ``(creator, title)``:

    * match found   -> every payload field is overwritten, compliance is
      re-scored and ``updated_at`` is set to the current UTC time;
    * no match      -> a new ``Script`` is built from the payload and scored.

    All changes are committed once, after the last row.
    """
    init_db()
    upserted = 0
    with get_session() as ses:
        for raw_row in _iter_jsonl(path):
            fields, title_key, creator_key = _payload_from_jsonl_row(raw_row)
            lookup = select(Script).where(
                Script.title == title_key,
                Script.creator == creator_key,
            )
            match = ses.exec(lookup).first()
            if not match:
                record = Script(**fields)
                _score_and_update_compliance(record)
            else:
                record = match
                # Overwrite every mapped column with the freshly imported value.
                for column, value in fields.items():
                    setattr(record, column, value)
                _score_and_update_compliance(record)
                record.updated_at = datetime.utcnow()
            ses.add(record)
            upserted += 1
        ses.commit()
    return upserted
| # ---- Ratings API ---- | |
def add_rating(script_id: int,
               overall: float,
               hook: Optional[float] = None,
               originality: Optional[float] = None,
               style_fit: Optional[float] = None,
               safety: Optional[float] = None,
               notes: Optional[str] = None,
               rater: str = "human") -> None:
    """Record one rating for ``script_id`` and refresh that script's cached averages.

    The ``Rating`` row is committed first; the aggregates on ``Script`` are
    then recomputed from all stored ratings and committed in a second step.
    Only ``overall`` is required; the per-dimension scores and ``notes`` may
    be ``None``.
    """
    rating_fields = {
        "script_id": script_id,
        "overall": overall,
        "hook": hook,
        "originality": originality,
        "style_fit": style_fit,
        "safety": safety,
        "notes": notes,
        "rater": rater,
    }
    with get_session() as ses:
        # Step 1: persist the rating event itself.
        ses.add(Rating(**rating_fields))
        ses.commit()
        # Step 2: rebuild the cached aggregates stored on the Script row.
        _recompute_script_aggregates(ses, script_id)
        ses.commit()
def _recompute_script_aggregates(ses: Session, script_id: int) -> None:
    """Rebuild the cached rating averages stored on one ``Script`` row.

    Loads every ``Rating`` for ``script_id`` and writes the per-field mean
    (rounded to 3 decimals, ignoring ``None`` values) plus ``ratings_count``
    and a fresh ``updated_at`` onto the script.  The script is added to
    ``ses`` but not committed; committing is left to the caller.

    Does nothing when there are no ratings, or when no ``Script`` with this id
    exists (ratings can be stored for an id that has no script row, and
    dereferencing the missing script would raise ``AttributeError``).
    """
    rows = list(ses.exec(select(Rating).where(Rating.script_id == script_id)))
    if not rows:
        return

    script = ses.get(Script, script_id)
    if script is None:
        # Ratings exist for an unknown script id: nothing to update.
        return

    def avg(field):
        # Mean over the ratings that actually set this field; None if none did.
        vals = [getattr(r, field) for r in rows if getattr(r, field) is not None]
        return round(sum(vals) / len(vals), 3) if vals else None

    script.score_overall = avg("overall")
    script.score_hook = avg("hook")
    script.score_originality = avg("originality")
    script.score_style_fit = avg("style_fit")
    script.score_safety = avg("safety")
    script.ratings_count = len(rows)
    script.updated_at = datetime.utcnow()
    ses.add(script)
| # ---- Public: Reference retrieval for generation ---- | |
def extract_snippets_from_script(s: Script, max_lines: int = 3) -> List[str]:
    """Return up to ``max_lines`` short, distinct text snippets taken from a script.

    Candidates, in order: the stripped hook, the first two beats (each
    stripped), and the stripped caption cut to 120 characters.  Falsy fields
    are skipped, empty strings are dropped, and duplicates keep only their
    first occurrence.
    """
    candidates: List[str] = []
    if s.hook:
        candidates.append(s.hook.strip())
    if s.beats:
        candidates += [beat.strip() for beat in s.beats[:2]]
    if s.caption:
        candidates.append(s.caption.strip()[:120])

    # dict keys keep insertion order, giving an order-preserving dedupe.
    distinct = dict.fromkeys(text for text in candidates if text)
    return list(distinct)[:max_lines]
def get_library_refs(creator: str, content_type: str, k: int = 6,
                     max_snippets: int = 8) -> List[str]:
    """Return distinct snippets from the ``k`` newest reference scripts.

    Scripts are filtered by ``creator`` and ``content_type``, must have
    ``is_reference`` set and a ``compliance`` value other than ``"fail"``,
    and are ordered by ``created_at`` descending.

    Args:
        creator: Exact creator name to match.
        content_type: Exact content type to match.
        k: Maximum number of scripts to read.
        max_snippets: Cap on the number of snippets returned (default 8).

    Returns:
        Snippets in script order, duplicates removed, at most ``max_snippets``.
    """
    with get_session() as ses:
        query = (
            select(Script)
            .where(
                Script.creator == creator,
                Script.content_type == content_type,
                Script.is_reference == True,  # noqa: E712 - builds a SQL expression
                Script.compliance != "fail",
            )
            .order_by(Script.created_at.desc())
            # Let the database apply the limit instead of loading every
            # matching row and slicing the list afterwards.
            .limit(k)
        )
        # Extract while the session is still open so attribute access never
        # touches a detached instance.
        snippets: List[str] = []
        for script in ses.exec(query):
            snippets.extend(extract_snippets_from_script(script))

    # Final order-preserving dedupe across scripts, then cap.
    return list(dict.fromkeys(snippets))[:max_snippets]
| # ---- HYBRID reference retrieval ---- | |
def get_hybrid_refs(creator: str, content_type: str, k: int = 6,
                    top_n: int = 3, explore_n: int = 2, newest_n: int = 1,
                    max_snippets: int = 8) -> List[str]:
    """Return snippets from a mix of best-scored, random and newest references.

    Candidate scripts match ``creator`` and ``content_type``, have
    ``is_reference`` set and a ``compliance`` value other than ``"fail"``.
    From them three buckets are built:

    * ``top_n`` scripts with the highest ``score_overall`` (missing = 0);
    * ``explore_n`` scripts sampled at random from those in neither other bucket;
    * ``newest_n`` scripts with the latest ``created_at``.

    Buckets are merged in the order best, explore, newest (deduplicated by
    id) and cut to ``k`` scripts.

    Snippets are then taken round-robin across the chosen scripts: first one
    snippet from each script, then a second from each, and so on.  With plain
    concatenation the best-scored scripts alone can fill the whole cap, so
    the explore and newest scripts would never show up in the result.

    Args:
        creator: Exact creator name to match.
        content_type: Exact content type to match.
        k: Maximum number of scripts to draw snippets from.
        top_n: Size of the best-scored bucket (negative values count as 0).
        explore_n: Size of the random bucket (negative values count as 0).
        newest_n: Size of the newest bucket (negative values count as 0).
        max_snippets: Cap on the number of snippets returned (default 8).

    Returns:
        Distinct snippets, at most ``max_snippets``; ``[]`` when nothing matches.
    """
    with get_session() as ses:
        all_refs = list(ses.exec(
            select(Script).where(
                Script.creator == creator,
                Script.content_type == content_type,
                Script.is_reference == True,  # noqa: E712 - builds a SQL expression
                Script.compliance != "fail",
            )
        ))
        if not all_refs:
            return []

        # Exploit: highest overall score first; unrated scripts count as 0.
        by_score = sorted(all_refs, key=lambda s: s.score_overall or 0.0, reverse=True)
        best = by_score[:max(top_n, 0)]

        # Freshness: newest first. A missing timestamp sorts as oldest rather
        # than raising on a None-vs-datetime comparison.
        by_age = sorted(all_refs, key=lambda s: s.created_at or datetime.min, reverse=True)
        newest = by_age[:max(newest_n, 0)]

        # Explore: random sample from whatever the other two buckets left over.
        # Membership is checked by id, not by comparing whole model objects.
        taken_ids = {s.id for s in best} | {s.id for s in newest}
        remainder = [s for s in all_refs if s.id not in taken_ids]
        explore = random.sample(remainder, min(max(explore_n, 0), len(remainder)))

        # Merge buckets, keeping first occurrence of each script, then cut to k.
        chosen_scripts = []
        seen_ids = set()
        for bucket in (best, explore, newest):
            for s in bucket:
                if s.id not in seen_ids:
                    seen_ids.add(s.id)
                    chosen_scripts.append(s)
        chosen_scripts = chosen_scripts[:k]

        # Extract while the session is open so no detached instance is read.
        per_script = [extract_snippets_from_script(s) for s in chosen_scripts]

    # Round-robin across scripts so each chosen script contributes before the
    # cap is reached.
    interleaved: List[str] = []
    deepest = max((len(lines) for lines in per_script), default=0)
    for depth in range(deepest):
        for lines in per_script:
            if depth < len(lines):
                interleaved.append(lines[depth])

    # Order-preserving dedupe across scripts, then cap to keep the prompt lean.
    return list(dict.fromkeys(interleaved))[:max_snippets]