Spaces:
Running
Running
| import os | |
| import re | |
| import json | |
| import math | |
| import random | |
| import requests | |
| import pymupdf4llm | |
| from concurrent.futures import ThreadPoolExecutor | |
# Hugging Face API token picked up from the environment at import time
# (empty string when the HF_TOKEN variable is not set).
DEFAULT_HF_TOKEN = os.getenv("HF_TOKEN", "")

# Model identifiers handed to the Hugging Face router.
FLASH_MODEL = "openai/gpt-oss-120b:fastest"
PRO_MODEL = "zai-org/GLM-4.7:fastest"
SAFETY_MODEL = "openai/gpt-oss-safeguard-20b"

# Length of the stream's "intro window", in seconds. Anchors that start
# before this offset are treated as intro anchors: viewers have not yet
# processed any spoken content, so Stage 2 chat for them should lean toward
# arrival/meta chatter rather than fully-formed topic takes. See
# select_intro_anchors() and the INTRO-WINDOW HANDLING clause in
# stage_2_generate_all_drafts().
INTRO_WINDOW_SECONDS = 10.0
def extract_youtube_video_id(url: str) -> str:
    """Pull the 11-character YouTube video id out of *url*.

    Accepts either a bare id or a URL containing one of the recognised
    markers (``v=``, ``/v/``, ``embed/``, ``shorts/``, ``youtu.be/``, ...).
    Surrounding whitespace is ignored. Returns an empty string when no id
    can be found.
    """
    candidate = url.strip()

    # A bare id: exactly 11 characters from the id alphabet.
    looks_like_bare_id = len(candidate) == 11 and re.match(r'^[a-zA-Z0-9_-]{11}$', candidate)
    if looks_like_bare_id:
        return candidate

    # Otherwise look for an id that follows one of the known URL markers.
    url_pattern = r'(?:v=|\/v\/|embed\/|shorts\/|youtu\.be\/|\/embed\/|\/watch\?v=|\&v=)([a-zA-Z0-9_-]{11})'
    found = re.search(url_pattern, candidate)
    if found is None:
        return ""
    return found.group(1)
def fetch_video_metadata(video_id: str, timeout: float = 5.0) -> dict:
    """Look up a video's public title and channel name via YouTube oEmbed.

    Performs one unauthenticated GET against the oEmbed endpoint for the
    watch URL of *video_id*, waiting at most *timeout* seconds.

    Returns ``{"title": ..., "author_name": ...}`` on success; missing or
    falsy fields are normalised to ``""``. Any exception (network error,
    non-2xx status, undecodable body, ...) is printed and swallowed, and an
    empty dict is returned instead, so callers must treat missing metadata
    as "identity unknown".
    """
    watch_url = f"https://www.youtube.com/watch?v={video_id}"
    query = {"url": watch_url, "format": "json"}
    try:
        reply = requests.get("https://www.youtube.com/oembed", params=query, timeout=timeout)
        reply.raise_for_status()
        body = reply.json()
        metadata = {}
        for field in ("title", "author_name"):
            # `or ""` turns an explicit null/empty value into an empty string.
            metadata[field] = body.get(field, "") or ""
        return metadata
    except Exception as e:
        print(f"oEmbed metadata fetch failed for video '{video_id}': {e}")
        return {}
def format_timestamp(seconds: float) -> str:
    """Render *seconds* as a bracketed ``[MM:SS]`` label.

    Fractions of a second are discarded. Minutes are not wrapped at 60, so
    an offset of one hour is rendered as ``[60:00]``.
    """
    whole_minutes, leftover = divmod(seconds, 60)
    return "[{:02d}:{:02d}]".format(int(whole_minutes), int(leftover))
def format_transcript_lines(transcript: list) -> str:
    """Join transcript entries into newline-separated ``[MM:SS] text`` lines.

    Each entry must provide ``"start"`` (seconds) and ``"text"``. An empty
    transcript yields an empty string.
    """
    return "\n".join(
        f"{format_timestamp(item['start'])} {item['text']}" for item in transcript
    )
def clean_json_text(text: str) -> str:
    """Strip a surrounding Markdown code fence from an LLM reply.

    Removes a leading ```` ``` ```` fence together with an optional ``json``
    language tag, and a trailing ```` ``` ```` fence, then trims whitespace.
    The language tag is matched case-insensitively, so ```` ```JSON ````
    and ```` ```Json ```` are handled as well as ```` ```json ````.
    Text without fences is returned stripped and otherwise unchanged.
    """
    text = text.strip()
    if text.startswith("```"):
        text = text[3:]
        # Drop the optional language tag regardless of its letter case;
        # leaving e.g. "JSON" in front of the payload breaks json.loads().
        if text[:4].lower() == "json":
            text = text[4:]
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()
def call_hf_router(model: str, messages: list, token: str) -> str:
    """Send a chat-completion request to the Hugging Face router.

    Posts *messages* for *model* (temperature 0.7, 300 s timeout) using
    *token* as a bearer credential and returns the ``content`` of the first
    choice in the response.

    Any exception (transport error, non-2xx status, unexpected response
    shape) triggers a retry, up to three attempts in total, sleeping 2 s
    after the first failure and 3 s after the second. The exception from the
    final attempt is re-raised to the caller.
    """
    import time

    endpoint = "https://router.huggingface.co/v1/chat/completions"
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    }
    request_body = {
        "model": model,
        "messages": messages,
        "temperature": 0.7
    }

    max_retries = 3
    attempt = 0
    while attempt < max_retries:
        try:
            response = requests.post(endpoint, headers=request_headers, json=request_body, timeout=300)
            response.raise_for_status()
            first_choice = response.json()["choices"][0]
            return first_choice["message"]["content"]
        except Exception as e:
            out_of_attempts = attempt == max_retries - 1
            if out_of_attempts:
                raise e
            print(f"HF API call failed (attempt {attempt+1}/{max_retries}): {e}. Retrying...")
            # Back-off grows with the attempt number: 2 s, then 3 s.
            time.sleep(2 ** attempt + 1)
        attempt += 1
def parse_srt(srt_text: str) -> list:
    """Parse SRT-style subtitle text into timed transcript entries.

    The text is split into cue blocks on blank lines. Within each block the
    first line matching ``HH:MM:SS,mmm --> HH:MM:SS,mmm`` (``,`` or ``.`` as
    the fraction separator) is taken as the timing line, and every line
    after it is joined with spaces to form the cue text. Lines before the
    timing line (such as a numeric cue index) are ignored, so cues without
    an index line are accepted too. Blocks with no timing line, or with no
    text after it, are skipped.

    Returns a list of ``{"text", "start", "duration"}`` dicts in input
    order, with ``start`` and ``duration`` in seconds.
    """
    # Match formats: 00:00:03,320 --> 00:00:05,960 or 00:00:03.320 --> 00:00:05.960
    timing_re = re.compile(
        r'(\d+):(\d+):(\d+)[,\.](\d+)\s*-->\s*(\d+):(\d+):(\d+)[,\.](\d+)'
    )

    entries = []
    # Normalize newlines so blank-line splitting works for CRLF / CR input.
    normalized = srt_text.replace('\r\n', '\n').replace('\r', '\n')
    for block in re.split(r'\n\s*\n', normalized.strip()):
        lines = [line.strip() for line in block.split('\n') if line.strip()]

        # Locate the timing line instead of assuming it is always the second
        # line: blocks without a leading cue index put it first.
        match = None
        timing_pos = -1
        for pos, line in enumerate(lines):
            match = timing_re.match(line)
            if match:
                timing_pos = pos
                break
        if match is None:
            continue

        text = " ".join(lines[timing_pos + 1:])
        if not text:
            # A cue with timing but no text carries nothing to transcribe.
            continue

        h1, m1, s1, ms1, h2, m2, s2, ms2 = map(int, match.groups())
        start_secs = h1 * 3600 + m1 * 60 + s1 + ms1 / 1000.0
        end_secs = h2 * 3600 + m2 * 60 + s2 + ms2 / 1000.0
        entries.append({
            "text": text,
            "start": start_secs,
            "duration": end_secs - start_secs,
        })
    return entries
def parse_transcript_text(text: str) -> list:
    """Turn pasted transcript text into ``{"text", "start", "duration"}`` entries.

    If the text contains ``-->`` it is first tried as SRT via parse_srt();
    a non-empty result is returned as-is, and any exception from that
    attempt is ignored.

    Otherwise (or when SRT parsing yields nothing) every non-blank line is
    treated as a paragraph with synthetic timing: the duration is
    ``words / 3`` seconds clamped to the range [3, 15], and consecutive
    paragraphs are separated by a one-second gap, starting at 0.0.
    """
    body = text.strip()

    if "-->" in body:
        try:
            timed_entries = parse_srt(body)
        except Exception:
            timed_entries = None
        if timed_entries:
            return timed_entries

    # Fallback: plain text, one entry per non-blank line.
    synthetic = []
    clock = 0.0
    for raw_line in body.split('\n'):
        paragraph = raw_line.strip()
        if not paragraph:
            continue
        word_count = len(paragraph.split())
        span = max(3.0, min(15.0, word_count / 3.0))
        synthetic.append({"text": paragraph, "start": clock, "duration": span})
        clock += span + 1.0
    return synthetic
| # --------------------------------------------------------------------------- | |
| # Anchor index utilities (deterministic, no LLM) | |
| # --------------------------------------------------------------------------- | |
| # Anchor IDs use the format "anc_N" where N is the chronological index. | |
| # validate_stage2_output relies on this naming to compare anchor ordering. | |
MIN_ANCHOR_DURATION = 3.0  # seconds
MAX_ANCHOR_DURATION = 10.0  # seconds


def build_anchor_index(transcript_entries: list) -> list:
    """Group raw transcript entries into short, ordered anchor chunks.

    Entries are sorted by ``start``. Overlapping timestamps are resolved by
    pushing each entry's effective start to at least the previous entry's
    effective end (the first entry is clamped against 0.0), and every entry
    is given a duration of at least 0.001 s.

    Entries are then accumulated into a chunk. The chunk is closed when it
    already spans at least MIN_ANCHOR_DURATION and taking the next entry
    would make it longer than MAX_ANCHOR_DURATION; that entry then opens the
    next chunk. A chunk shorter than the minimum keeps absorbing entries
    even if that pushes it past the maximum. Whatever is left at the end is
    emitted as a final chunk.

    Returns a chronologically ordered, non-overlapping list of
    ``{"anchor_id", "start", "end", "text"}`` dicts, where ``anchor_id`` is
    ``"anc_N"`` with N the position in the list. Silent gaps present between
    transcript entries can remain between anchors. An empty input gives an
    empty list.
    """
    if not transcript_entries:
        return []

    # Pass 1: chronological order with overlaps removed.
    timeline = []
    cursor = 0.0
    for item in sorted(transcript_entries, key=lambda e: e["start"]):
        begin = max(float(item["start"]), cursor)
        cursor = begin + max(float(item["duration"]), 0.001)
        timeline.append((begin, cursor, item["text"]))

    anchors = []

    def emit(first, last, pieces):
        # The id is derived from how many anchors have been emitted so far.
        anchors.append({
            "anchor_id": f"anc_{len(anchors)}",
            "start": first,
            "end": last,
            "text": " ".join(pieces),
        })

    # Pass 2: greedy chunking.
    open_start = timeline[0][0]
    open_end = open_start
    open_texts = []
    for begin, finish, words in timeline:
        long_enough = (open_end - open_start) >= MIN_ANCHOR_DURATION
        would_overflow = (finish - open_start) > MAX_ANCHOR_DURATION
        if open_texts and long_enough and would_overflow:
            emit(open_start, open_end, open_texts)
            open_start = begin
            open_texts = []
        open_texts.append(words)
        open_end = finish

    # Flush the trailing (possibly short) chunk.
    if open_texts:
        emit(open_start, open_end, open_texts)
    return anchors
def select_intro_anchors(anchors: list, window_seconds: float = INTRO_WINDOW_SECONDS) -> set:
    """Collect the ids of anchors that begin inside the intro window.

    An anchor qualifies when its ``start`` (converted to float) is strictly
    less than *window_seconds*, which defaults to INTRO_WINDOW_SECONDS.

    Pure and deterministic: no LLM calls and no I/O. The resulting set is
    meant to be passed to stage_2_generate_all_drafts() so that chat tied to
    these early anchors is biased toward arrival/meta chatter.
    """
    intro_ids = set()
    for anchor in anchors:
        starts_in_window = float(anchor["start"]) < window_seconds
        if starts_in_window:
            intro_ids.add(anchor["anchor_id"])
    return intro_ids
def map_anchors_to_segments(anchors: list, segments: list) -> dict:
    """Assign each anchor to exactly one topic segment.

    Assignment is based on each anchor's midpoint: the first segment whose
    half-open range ``[start, end)`` contains the midpoint wins. If the
    midpoint falls outside every segment (for example a coverage gap in the
    segmenter output), the anchor goes to the segment whose own midpoint is
    nearest. Segment bounds are passed through ``float()``, so numeric
    strings are accepted.

    Returns ``{segment_index: [anchor_id, ...]}`` with an entry (possibly an
    empty list) for every segment index. When *segments* is empty there is
    nothing to assign to, so an empty dict is returned and the anchors are
    left unassigned.
    """
    # Pre-seed every segment index with an empty list.
    mapping = {i: [] for i in range(len(segments))}
    if not segments:
        # Without this guard the nearest-segment fallback would call min()
        # on an empty range and raise ValueError.
        return mapping

    def seg_mid(seg):
        return (float(seg["start"]) + float(seg["end"])) / 2.0

    for anchor in anchors:
        midpoint = (anchor["start"] + anchor["end"]) / 2.0

        # Find the first segment whose range contains the midpoint.
        assigned = None
        for i, seg in enumerate(segments):
            if float(seg["start"]) <= midpoint < float(seg["end"]):
                assigned = i
                break

        # Fallback: nearest segment by midpoint distance (lowest index wins ties).
        if assigned is None:
            assigned = min(
                range(len(segments)),
                key=lambda i: abs(seg_mid(segments[i]) - midpoint),
            )
        mapping[assigned].append(anchor["anchor_id"])
    return mapping
def validate_stage2_output(payload: list, valid_anchor_ids: set) -> list:
    """Validate the LLM's Stage 2 output against the temporal-alignment contract.

    *payload* is the list of per-segment objects, each carrying a
    ``"messages"`` list. Returns a list of error strings (empty = valid).
    Malformed entries (non-object segments or messages, unusable ids,
    non-string anchor ids, missing text) are reported as errors rather than
    raising.

    Checks:
    - Every message has an 'id' and an 'anchor_id'. ('username' and 'text'
      are not checked here.)
    - anchor_id is in valid_anchor_ids.
    - Message ids are unique across the entire payload.
    - reply_to (if present) references a known message id.
    - reply_to anchor index <= replier anchor index (no replying to the future).
    - No cycles in the reply graph (same-anchor or cross-anchor). A cycle is
      reported once for every message whose reply chain runs into it.
    """
    errors = []

    def hashable(value) -> bool:
        """Return True when *value* can be used as a dict key / set member."""
        try:
            hash(value)
        except TypeError:
            return False
        return True

    def anchor_index(aid) -> int:
        """Parse integer index from 'anc_N' format; return large sentinel on failure."""
        try:
            return int(aid.split("_", 1)[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            # AttributeError/TypeError cover anchor ids that are not strings.
            return 10 ** 9

    # Flatten all segments into one message list, skipping malformed entries.
    all_messages = []
    for seg in payload:
        if not isinstance(seg, dict):
            errors.append(f"Segment entry is not an object: {str(seg)[:40]!r}")
            continue
        seg_messages = seg.get("messages") or []
        if not isinstance(seg_messages, list):
            errors.append(f"Segment 'messages' is not a list: {str(seg_messages)[:40]!r}")
            continue
        for msg in seg_messages:
            if isinstance(msg, dict):
                all_messages.append(msg)
            else:
                errors.append(f"Message entry is not an object: {str(msg)[:40]!r}")

    # Build id -> message index and check for duplicates.
    id_to_msg = {}
    for msg in all_messages:
        mid = msg.get("id")
        if mid is None:
            # `or ''` guards against an explicit null text.
            errors.append(f"Message missing 'id' field: {str(msg.get('text') or '')[:40]!r}")
            continue
        if not hashable(mid):
            errors.append(f"Message has unusable 'id' value: {str(mid)[:40]!r}")
            continue
        if mid in id_to_msg:
            errors.append(f"Duplicate message id: {mid!r}")
        else:
            id_to_msg[mid] = msg

    for msg in all_messages:
        mid = msg.get("id")
        aid = msg.get("anchor_id")

        # Required fields.
        if aid is None:
            errors.append(f"Message {mid!r} missing 'anchor_id'")
        elif not hashable(aid) or aid not in valid_anchor_ids:
            errors.append(f"Message {mid!r} references unknown anchor_id {aid!r}")

        # reply_to validation.
        reply_to = msg.get("reply_to")
        if reply_to is not None:
            if not hashable(reply_to) or reply_to not in id_to_msg:
                errors.append(f"Message {mid!r} has reply_to {reply_to!r} which does not exist")
            elif aid is not None:
                parent_aid = id_to_msg[reply_to].get("anchor_id")
                if parent_aid is not None and anchor_index(parent_aid) > anchor_index(aid):
                    errors.append(
                        f"Message {mid!r} (anchor {aid}) has reply_to {reply_to!r} "
                        f"(anchor {parent_aid}) which is later in the video"
                    )

    # Cycle detection: follow each message's reply_to chain until it ends
    # or revisits a node.
    reply_graph = {
        msg["id"]: msg.get("reply_to")
        for msg in all_messages
        if "id" in msg and hashable(msg["id"])
    }
    for start_id in reply_graph:
        visited = set()
        node = start_id
        while node is not None and hashable(node):
            if node in visited:
                errors.append(f"Cycle detected in reply_to chain starting at {start_id!r}")
                break
            visited.add(node)
            node = reply_graph.get(node)  # None if no reply_to or id not in graph
    return errors
def compute_display_times(messages: list, anchors_by_id: dict) -> list:
    """Assign a numeric displayTime to every message.

    Placement rules, with ``hi = min(8, anchor_duration)``:
    - Message whose anchor_id is not in *anchors_by_id*: displayTime 0.0.
    - Root message on anchor "anc_0":
      anchor.start + uniform(0.5, max(1.0, hi)).
    - Root message on any other anchor:
      anchor.start + uniform(2.0, max(2.0, hi)).
    - Reply: max(parent.displayTime + uniform(2, 12), own_anchor.start).
    - Reply whose parent can never be placed (cycle, or reply_to pointing at
      an unknown id): demoted to a root-style placement,
      anchor.start + uniform(1, max(1.0, hi)), and a warning is printed.

    Because the jitter upper bound has a floor (1.0 or 2.0), a message on an
    anchor shorter than that floor can land after the anchor's end.

    Messages are placed in repeated passes so that parents are resolved
    before their replies. The passes stop as soon as one of them places
    nothing new, at which point the leftovers are demoted.

    Returns a new list of shallow-copied message dicts with 'displayTime'
    added; the input dicts are not modified.
    """
    # Work on shallow copies to avoid mutating the input.
    result = [dict(m) for m in messages]
    resolved = {}  # id -> displayTime
    remaining = list(result)

    while remaining:
        still_waiting = []
        for msg in remaining:
            mid = msg.get("id")
            anchor = anchors_by_id.get(msg.get("anchor_id", ""))
            if anchor is None:
                # Unknown anchor: assign a safe fallback of 0.
                msg["displayTime"] = 0.0
                if mid:
                    resolved[mid] = 0.0
                continue

            reply_to = msg.get("reply_to")
            if reply_to and reply_to not in resolved:
                # Parent not yet placed; come back to this message.
                still_waiting.append(msg)
                continue

            if reply_to:
                # Reply: some seconds after the parent, never before the
                # start of its own anchor.
                jitter = random.uniform(2, 12)
                dt = max(resolved[reply_to] + jitter, anchor["start"])
            else:
                # Root message: jitter capped by the anchor duration.
                hi = min(8.0, anchor["end"] - anchor["start"])
                if anchor.get("anchor_id") == "anc_0":
                    jitter = random.uniform(0.5, max(1.0, hi))
                else:
                    jitter = random.uniform(2.0, max(2.0, hi))
                dt = anchor["start"] + jitter

            msg["displayTime"] = dt
            if mid:
                resolved[mid] = dt

        # If this pass placed nothing, no later pass can either: `resolved`
        # did not change, so the same messages would keep waiting.
        made_progress = len(still_waiting) < len(remaining)
        remaining = still_waiting
        if not made_progress:
            break

    # Any messages still unresolved get demoted to root placement.
    for msg in remaining:
        anchor = anchors_by_id.get(msg.get("anchor_id", ""))
        if anchor:
            hi = min(8.0, anchor["end"] - anchor["start"])
            msg["displayTime"] = anchor["start"] + random.uniform(1, max(1.0, hi))
        else:
            msg["displayTime"] = 0.0
        print(f"WARNING: compute_display_times: demoted {msg.get('id')!r} to root (unresolvable reply chain)")
    return result
# --- STAGE 1a: Segment transcript ---
def stage_1a_segment_transcript(transcript_text: str, token: str) -> list:
    """Ask the Flash model to split a timestamped transcript into topic segments.

    *transcript_text* is sent verbatim as the user message. The reply is
    stripped of Markdown fences with clean_json_text() and decoded with
    ``json.loads``; the decoded value is returned unchanged (the prompt asks
    for a list of ``{"start", "end", "text"}`` objects, but the shape is not
    validated here).

    Raises whatever call_hf_router() raises, or ``json.JSONDecodeError`` if
    the reply is not valid JSON.
    """
    instructions = (
        "You are an AI assistant that segments a video transcript into logical topic segments. "
        "You are given a transcript formatted with timestamps.\n"
        "Your task is to group consecutive lines into segments of about 30 to 60 seconds (but align with natural topic boundaries).\n"
        "For each segment, output:\n"
        "- start: the start time in seconds (float or int)\n"
        "- end: the end time in seconds (float or int)\n"
        "- text: the concatenated text in this segment\n\n"
        "Return ONLY a JSON list of segments. Do not include markdown wraps or conversational text outside the JSON."
    )
    conversation = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": f"Transcript:\n{transcript_text}"},
    ]
    raw_reply = call_hf_router(FLASH_MODEL, conversation, token)
    return json.loads(clean_json_text(raw_reply))
| # --- STAGE 2: Generate all draft comments (Pro model) --- | |
def stage_2_generate_all_drafts(segments: list, doc_text: str, token: str, anchor_map: dict = None,
                                all_anchors: list = None, intro_anchor_ids: set = None,
                                video_metadata: dict = None) -> list:
    """Generate draft chat messages for all segments with one Pro-model call.

    segments: topic segments, each with 'start', 'end' and 'text'.
    doc_text: reference document text, embedded verbatim in the prompt.
    token: Hugging Face API token forwarded to call_hf_router().
    anchor_map: {seg_idx: [anchor_id, ...]} from map_anchors_to_segments.
    all_anchors: full anchor list from build_anchor_index (for text lookup).
        When both are provided, each segment's anchor list is injected into
        the prompt so the model can pick a concrete timestamped moment per
        message. Anchor text is truncated to 80 characters in the prompt.
    intro_anchor_ids: set of anchor_ids from select_intro_anchors(). These
        are tagged [INTRO] in the anchor lists, and when the set is
        non-empty the INTRO-WINDOW HANDLING instructions are appended to the
        system prompt.
    video_metadata: optional {"title": ..., "author_name": ...} from
        fetch_video_metadata(). When a non-empty title or channel is
        present, the model may (optionally) generate up to 1-2
        identity-aware intro messages; otherwise it is told not to guess.

    Returns the decoded JSON reply (the prompt asks for one object per
    segment with '_internal_logic' and 'messages'); the shape is not
    validated here.

    Raises whatever call_hf_router() raises. If the reply is not valid JSON
    the raw content is printed and the decode error is re-raised.
    """
    # Build anchor lookup if available.
    anchors_by_id = {a["anchor_id"]: a for a in (all_anchors or [])}
    intro_anchor_ids = intro_anchor_ids or set()

    system_prompt = (
        "You are simulating audience chat reactions for a livestream of an educational or historical video.\n"
        "You are given:\n"
        "1. A reference document.\n"
        "2. A chronological list of video segments. Each segment lists its ANCHOR MOMENTS: timestamped transcript"
        " chunks at the sentence/clause level.\n\n"
        "Your task is to generate 8 to 15 draft chat messages from different users reacting to EACH video segment.\n\n"
        "Crucially, you must follow these steps:\n"
        "1. Identify the conceptual relationships (e.g., direct validations, unintentional contradictions, theoretical bridges) between the video and the document.\n"
        "2. Frame reactions to each caption segment by drawing on the identified conceptual relationships without direct reference or quotes. AVOID academic jargon.\n"
        "3. Generate chat messages for each segment. Every message MUST be anchored to a specific moment from that segment's anchor list.\n\n"
        "QUOTAS AND PERSONAS (STRICTLY ENFORCED):\n"
        "- Include at least 50% on-topic contributions reflecting agreement, disagreement, jokes, observations, predictions, corrections, hype and skepticism to complement occasional off-topic remarks, each with an anchor_id.\n"
        "- Ensure diverse, non-repetitive usernames across the entire video. Do not use the same usernames repeatedly for the same types of comments.\n"
        "- Maintain diverse livestream audience personas: some are experts reading deeply into philosophical tension, some take things entirely at face value, some only react emotionally or to video aesthetics, some use sarcasm/memes.\n\n"
        "REPLY THREADING (optional):\n"
        "- A message MAY include a 'reply_to' field containing another message's 'id' from the same or immediately preceding anchor.\n"
        "- A reply must NEVER reference a message anchored to a later anchor than itself.\n"
    )

    if intro_anchor_ids:
        # `or ""` keeps .strip() safe if a metadata value is None.
        metadata = video_metadata or {}
        title = (metadata.get("title") or "").strip()
        author = (metadata.get("author_name") or "").strip()
        if title or author:
            channel_part = f" from channel \"{author}\"" if author else ""
            identity_clause = (
                f"- IDENTITY-AWARE INTRO MESSAGES (OPTIONAL): The video title is \"{title or 'unknown'}\"{channel_part}. "
                "If — and only if — this title, channel, or the reference document make the speaker's "
                "identity or significance unambiguous, you MAY include up to 1-2 [INTRO]-anchored "
                "messages reacting to who the speaker is or why they matter (e.g. recognizing a "
                "well-known figure). Do not guess or assert identity beyond what these sources support.\n"
            )
        else:
            identity_clause = (
                "- IDENTITY-AWARE INTRO MESSAGES: No reliable video title/channel metadata is available. "
                "Do NOT include any [INTRO]-anchored messages that guess or assert who the speaker is — "
                "keep all [INTRO] messages generic arrival/meta chatter.\n"
            )
        # NOTE(review): the text below tells the model the intro window is
        # "~20 seconds", while INTRO_WINDOW_SECONDS is 10.0. The window used
        # to build intro_anchor_ids is not visible here, so the wording is
        # left as-is; confirm which value is intended.
        system_prompt += (
            "\nINTRO-WINDOW HANDLING:\n"
            "- Anchors marked [INTRO] in the ANCHOR MOMENTS lists below fall within the stream's first "
            "~20 seconds. Real viewers haven't processed any spoken content yet at this point — they're "
            "still arriving, reading the title, or reacting to who/what the speaker is sharing.\n"
            "- For messages anchored to an [INTRO] anchor, 80-90% should be arrival/meta chatter NOT engagement "
            "with the spoken content. This OVERRIDES the normal off-topic quota for these messages.\n"
            "- Set '_internal_logic' to \"None\" for any segment whose messages are dominated by "
            "[INTRO] anchors, regardless of any document tie.\n"
            f"{identity_clause}"
        )

    system_prompt += (
        "\nFORMAT: Return a JSON list of objects, one per segment:\n"
        "{\n"
        "  \"_internal_logic\": \"How this segment relates to document sub-claims, or 'None' if off-topic.\",\n"
        "  \"messages\": [\n"
        "    {\n"
        "      \"id\": \"<unique string, e.g. m1>\",\n"
        "      \"username\": \"<username>\",\n"
        "      \"text\": \"<message text>\",\n"
        "      \"anchor_id\": \"<anchor_id from this segment's anchor list>\",\n"
        "      \"reply_to\": \"<optional: another message's id>\"\n"
        "    }\n"
        "  ]\n"
        "}\n\n"
        "CRITICAL: Do not include 'timestamp'. Do not omit 'anchor_id' or 'id'. "
        "Return ONLY a valid JSON list. Do not include markdown wraps or other text."
    )

    def format_segment(i, seg):
        """Render one segment, plus its anchor list when anchors are available."""
        seg_header = f"Segment {i} ({seg['start']}s - {seg['end']}s):\n{seg['text']}"
        if anchor_map and all_anchors:
            anchor_lines = []
            for aid in anchor_map.get(i, []):
                a = anchors_by_id.get(aid)
                if a:
                    intro_tag = " [INTRO]" if aid in intro_anchor_ids else ""
                    anchor_lines.append(f"  [{aid}]{intro_tag} {a['start']:.1f}s: {a['text'][:80]}")
            if anchor_lines:
                seg_header += "\nANCHOR MOMENTS (pick one per message):\n" + "\n".join(anchor_lines)
        return seg_header

    segments_text = "\n\n".join(format_segment(i, seg) for i, seg in enumerate(segments))
    user_prompt = (
        f"Reference Document:\n{doc_text}\n\n"
        f"Video Segments:\n{segments_text}\n\n"
        "Generate the JSON list of draft comments for all segments."
    )
    llm_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]
    content = call_hf_router(PRO_MODEL, llm_messages, token)
    cleaned = clean_json_text(content)
    try:
        return json.loads(cleaned)
    except Exception:
        # Surface the raw reply so a malformed response can be diagnosed.
        print(f"Failed to parse JSON from Stage 2. Raw content: {content}")
        raise
| # --- STAGE 3: Style and pacing (Flash model) --- | |
def stage_3_stylize_segment(draft_data: dict, token: str) -> dict:
    """Run the Flash-model style pass over one segment's draft messages.

    *draft_data* is serialised with ``json.dumps`` and sent as the user
    message. The system prompt asks the model to keep personas, usernames
    and the structural fields ('id', 'anchor_id', 'reply_to') intact while
    polishing the wording and softening harmful content. The reply is
    stripped of Markdown fences and decoded with ``json.loads``; the decoded
    value is returned without further validation.

    Raises whatever call_hf_router() raises, or ``json.JSONDecodeError`` if
    the reply is not valid JSON.
    """
    polish_instructions = (
        "You are a style polisher for livestream chat replays (YouTube/Twitch).\n"
        "Your job is to take raw draft chat messages and perform a final flourish and alignment pass to make them sound authentic.\n\n"
        "CRITICAL INSTRUCTIONS:\n"
        "1. PRESERVE DIVERSITY: The draft already contains carefully balanced personas (jokers, experts, off-topic, skeptics). DO NOT homogenize them. If a message is off-topic, keep it off-topic. If it's a joke about the video, keep it a joke.\n"
        "2. PRESERVE USERNAMES: You MUST use the exact usernames provided in the draft. Do not invent new ones.\n"
        "3. PRESERVE STRUCTURAL FIELDS: Keep 'id', 'anchor_id', and 'reply_to' (if present) on every message exactly as given. Do not rename, remove, or alter these fields.\n"
        "4. ADD FLOURISH: Make them short, concise, and lively. Occasionally inject internet slang and standard emotes where appropriate, but don't overdo it.\n"
        "5. Avoid sounding like AI-generated summaries. Do not append emotes to every single message.\n"
        "6. SAFETY PASS: If any message crosses from edgy/sarcastic banter into harassment, slurs, hate speech, or denigration of real people or groups, rewrite it to keep the same persona and sentiment but remove the harmful element.\n\n"
        "Return ONLY the updated JSON with the exact same structure. Do not include markdown wraps."
    )
    conversation = [
        {"role": "system", "content": polish_instructions},
        {"role": "user", "content": f"Draft JSON:\n{json.dumps(draft_data)}"},
    ]
    raw_reply = call_hf_router(FLASH_MODEL, conversation, token)
    return json.loads(clean_json_text(raw_reply))
| # --- Safety gates --- | |
def check_content_safety(doc_text: str, video_metadata: dict, token: str) -> tuple:
    """Pre-pipeline hard-stop gate (runs before Stage 2's expensive Pro call).

    Classifies whether the reference document, combined with whatever video
    title/channel context is available, is appropriate source material for
    a simulated livestream chat. Only the first 4000 characters of
    *doc_text* are sent.

    Intentionally coarse: a single classification call to SAFETY_MODEL, not
    a per-message filter (see final_safety_scan() for the post-generation
    pass). The prompt asks for an "unsafe" verdict only when the natural
    simulated chat would likely be hate speech, harassment of real
    people/groups, glorification of or instructions for violence/self-harm,
    or sexual content involving minors. General
    controversial-but-legitimate material is to be marked safe.

    Returns ``(is_safe, reason)``. The model's ``"safe"`` field is accepted
    either as a JSON boolean or as a string; the strings "false", "no", "0"
    and "" (case-insensitive) count as unsafe. A missing ``"safe"`` field is
    treated as safe.

    Design note: fails OPEN. Any error while calling the classifier or
    decoding its reply yields ``(True, "")`` so a transient API hiccup does
    not block legitimate users. The decision (or failure) is always logged.
    """
    metadata = video_metadata or {}
    # `or ""` keeps .strip() safe if a metadata value is None.
    title = (metadata.get("title") or "").strip()
    author = (metadata.get("author_name") or "").strip()

    system_prompt = (
        "You are a content-safety gate for an app that generates SIMULATED livestream chat "
        "reactions to educational/informational videos, based on a reference document.\n"
        "Given a video's title/channel (if known) and an excerpt of the reference document, "
        "decide whether this is appropriate source material for that simulation.\n\n"
        "Mark unsafe ONLY if the natural simulated chat reactions to this material would likely "
        "include hate speech, harassment of real people or groups, glorification of or "
        "instructions for violence or self-harm, or sexual content involving minors. General "
        "controversial-but-legitimate topics (politics, policy debates, science controversies, "
        "history of atrocities discussed academically, etc.) are SAFE.\n\n"
        "Respond with ONLY a JSON object: {\"safe\": true|false, \"reason\": \"<one short sentence>\"}. "
        "Do not include markdown wraps or other text."
    )
    excerpt = (doc_text or "")[:4000]
    user_prompt = (
        f"Video title: {title or '(unknown)'}\n"
        f"Channel: {author or '(unknown)'}\n\n"
        f"Reference document excerpt:\n{excerpt}"
    )

    try:
        content = call_hf_router(SAFETY_MODEL, [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ], token)
        result = json.loads(clean_json_text(content))

        raw_safe = result.get("safe", True)
        if isinstance(raw_safe, str):
            # bool("false") is True, so a string verdict must be read
            # explicitly or an "unsafe" answer would pass the gate.
            is_safe = raw_safe.strip().lower() not in ("", "false", "no", "0")
        else:
            is_safe = bool(raw_safe)

        reason = result.get("reason", "")
        if not is_safe:
            print(f"Content safety gate REJECTED input: {reason}")
        return is_safe, reason
    except Exception as e:
        print(f"Content safety gate failed to run ({e}); failing open (treating as safe).")
        return True, ""
def final_safety_scan(messages: list, token: str) -> list:
    """Aggregate post-generation moderation pass over all flattened,
    stylized messages.

    Stage 3 already asks its model to soften individual messages, but that
    runs per-segment and can miss patterns only visible across the whole
    chat. This is one additional call to SAFETY_MODEL over the full message
    list (only 'id' and 'text' are sent) as a second net.

    For each message, the model returns one of:
    - "keep"    -> message is left unchanged.
    - "replace" -> a copy of the message gets its 'text' swapped for the
                   supplied 'replacement'. If no non-empty string
                   replacement is supplied, the message is kept unchanged.
    - "drop"    -> message is removed from the output entirely.
    Actions are compared case-insensitively with surrounding whitespace
    ignored. Messages without a matching decision, or with an unrecognised
    action, are kept.

    Fails OPEN (returns `messages` unchanged) on any classifier or parsing
    error, since this is a best-effort secondary net rather than the primary
    gate (see check_content_safety for the pre-pipeline hard stop). An empty
    input is returned as-is without calling the model.
    """
    if not messages:
        return messages

    system_prompt = (
        "You are a final moderation pass for SIMULATED livestream chat messages (fictional "
        "audience reactions, not real users).\n"
        "For EVERY message below, decide one action:\n"
        "- \"keep\": fine as-is. Edgy/sarcastic humor, strong opinions, and in-group banter are fine.\n"
        "- \"replace\": the message crosses into harassment, slurs, hate speech, or denigration of "
        "real people or groups. Provide a 'replacement' string with the same persona/sentiment but "
        "with the harmful element removed.\n"
        "- \"drop\": the message is irredeemable and should be removed entirely.\n\n"
        "Return ONLY a JSON list, one entry per input message, in any order:\n"
        "[{\"id\": \"<id>\", \"action\": \"keep|replace|drop\", \"replacement\": \"<only if action=replace>\"}]\n"
        "Do not include markdown wraps or other text."
    )
    payload = [{"id": m.get("id"), "text": m.get("text", "")} for m in messages]
    user_prompt = f"Messages:\n{json.dumps(payload, ensure_ascii=False)}"

    try:
        content = call_hf_router(SAFETY_MODEL, [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ], token)
        decisions = json.loads(clean_json_text(content))
        decisions_by_id = {d.get("id"): d for d in decisions if isinstance(d, dict)}
    except Exception as e:
        print(f"Final safety scan failed to run ({e}); failing open (returning messages unchanged).")
        return messages

    result = []
    dropped, replaced = 0, 0
    for msg in messages:
        decision = decisions_by_id.get(msg.get("id")) or {}
        # Normalise so "Drop" or "REPLACE " are not silently treated as "keep".
        action = str(decision.get("action", "keep")).strip().lower()
        if action == "drop":
            dropped += 1
            continue
        if action == "replace":
            replacement = decision.get("replacement")
            # Only a real, non-empty string may become the message text.
            if isinstance(replacement, str) and replacement.strip():
                msg = dict(msg)
                msg["text"] = replacement
                replaced += 1
        result.append(msg)

    if dropped or replaced:
        print(f"Final safety scan: dropped {dropped}, replaced {replaced} of {len(messages)} messages.")
    return result
| # --- Parallel Tasks --- | |
def _fetch_and_segment_transcript(video_id: str, transcript_text: str, token: str) -> tuple:
    """Parse the supplied transcript and segment it with Stage 1a.

    Returns ``(segments, raw_transcript)`` so the caller can also build the
    anchor index from the raw entries. *video_id* is accepted but not used
    by this function.

    Raises ValueError when *transcript_text* is empty or None.
    """
    if transcript_text:
        print("Parsing provided transcript text...")
        raw_transcript = parse_transcript_text(transcript_text)
        timestamped_lines = format_transcript_lines(raw_transcript)

        print("Stage 1a: Segmenting transcript...")
        return stage_1a_segment_transcript(timestamped_lines, token), raw_transcript

    raise ValueError("Transcript text is required but was not provided.")
| def _extract_document_text(doc_text: str, doc_path: str, use_ocr: bool = False) -> str: | |
| if doc_path: | |
| print(f"Extracting text from PDF: {doc_path}...") | |
| return pymupdf4llm.to_markdown(doc_path, use_ocr=use_ocr) | |
| if doc_text: | |
| return doc_text | |
| raise ValueError("Either doc_text or doc_path must be provided.") | |
| # --- Full Pipeline Orchestration --- | |
def run_livestream_pipeline(video_id: str, doc_text: str = None, doc_path: str = None, transcript_text: str = None, token: str = None, use_ocr: bool = False) -> list:
    """Run the full pipeline and return a flat list of messages sorted by displayTime.

    Return value shape changed from list[segment_dict] to list[message_dict].
    Each message has: id, username, text, anchor_id, displayTime,
    and optionally reply_to.

    Args:
        video_id: Video id used for the metadata fetch (and forwarded to the
            transcript task).
        doc_text: Reference document as text; used when doc_path is not given.
        doc_path: Path of a PDF to extract the reference document from.
        transcript_text: Transcript text; required.
        token: Hugging Face API token. Falls back to the HF_TOKEN environment
            variable when not supplied.
        use_ocr: Forwarded to the PDF extraction step.

    Raises:
        ValueError: If no token is available, if neither doc_text nor doc_path
            is given, if transcript_text is missing, if the content safety
            check rejects the input, or if Stage 2 output fails validation.
    """
    if not token:
        token = os.environ.get("HF_TOKEN", DEFAULT_HF_TOKEN)
    if not token:
        raise ValueError(
            "Hugging Face API Token not found. Please set the 'HF_TOKEN' secret in your Space settings "
            "or provide it in the input box."
        )

    # Fail fast on missing inputs. The worker tasks raise these same errors,
    # but only from inside the thread pool, and the executor's `with` block
    # waits for every submitted task before the exception can propagate, so
    # an invalid request would still run the other tasks to completion.
    if not doc_text and not doc_path:
        raise ValueError("Either doc_text or doc_path must be provided.")
    if not transcript_text:
        raise ValueError("Transcript text is required but was not provided.")

    print("Starting parallel execution of Document Extraction, Transcript Segmentation, and Metadata Fetch...")
    with ThreadPoolExecutor(max_workers=3) as executor:
        fut_doc = executor.submit(_extract_document_text, doc_text, doc_path, use_ocr)
        fut_seg = executor.submit(_fetch_and_segment_transcript, video_id, transcript_text, token)
        fut_meta = executor.submit(fetch_video_metadata, video_id)
        extracted_doc_text = fut_doc.result()
        segments, raw_transcript = fut_seg.result()
        video_metadata = fut_meta.result()
    print(f"Segmented into {len(segments)} blocks.")
    if video_metadata:
        print(f"Video metadata: title={video_metadata.get('title')!r}, channel={video_metadata.get('author_name')!r}")
    else:
        print("Video metadata unavailable (oEmbed fetch failed) — identity-aware intro messages disabled.")

    # Build anchor index deterministically from raw transcript (no LLM call).
    anchors = build_anchor_index(raw_transcript)
    anchor_map = map_anchors_to_segments(anchors, segments)
    anchors_by_id = {a["anchor_id"]: a for a in anchors}
    valid_anchor_ids = set(anchors_by_id.keys())
    intro_anchor_ids = select_intro_anchors(anchors)
    print(f"Built anchor index: {len(anchors)} anchors ({len(intro_anchor_ids)} in the intro window).")

    # Hard stop: gate on content safety before the expensive Stage 2/3 calls.
    is_safe, reason = check_content_safety(extracted_doc_text, video_metadata, token)
    if not is_safe:
        raise ValueError(
            f"Content safety check failed: {reason} "
            "This input was flagged as unsuitable for simulated chat generation. "
            "Please choose a different reference document or video."
        )

    # Stage 2: Single Pro model call for all drafting.
    print("Stage 2: Generating draft comments for all segments (Pro model)...")
    draft_segments = stage_2_generate_all_drafts(
        segments, extracted_doc_text, token,
        anchor_map=anchor_map, all_anchors=anchors,
        intro_anchor_ids=intro_anchor_ids, video_metadata=video_metadata
    )

    # Validate Stage 2 output before passing downstream.
    validation_errors = validate_stage2_output(draft_segments, valid_anchor_ids)
    if validation_errors:
        # Log and raise — malformed anchor/reply data must not flow into Stage 3.
        error_summary = "\n".join(validation_errors[:10])
        raise ValueError(f"Stage 2 output failed validation ({len(validation_errors)} error(s)):\n{error_summary}")

    # Stage 3: Parallel stylization. Results are collected in submission
    # order, so final_segments lines up with draft_segments.
    print("Stage 3: Stylizing comments...")
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(stage_3_stylize_segment, draft, token)
            for draft in draft_segments
        ]
        final_segments = [fut.result() for fut in futures]

    # Flatten all messages.
    all_messages = []
    for seg in final_segments:
        all_messages.extend(seg.get("messages", []))

    # Final aggregate moderation pass (second net, after Stage 3's per-segment pass).
    print("Running final safety scan over all messages...")
    all_messages = final_safety_scan(all_messages, token)

    # Compute displayTime, sort numerically.
    all_messages = compute_display_times(all_messages, anchors_by_id)
    all_messages.sort(key=lambda m: m.get("displayTime", 0))
    return all_messages