feat: Add AI-driven verification and correction for speech-to-text timed transcripts.
Browse files
- src/google_src/ai_studio_sdk.py +17 -0
- src/google_src/stt.py +43 -1
- src/pipelines/ai_pipeline.py +1 -1
- src/prompt/stt_verification.md +17 -0
src/google_src/ai_studio_sdk.py
CHANGED
|
@@ -328,6 +328,23 @@ def _get_mock_response(prompt: str) -> str:
|
|
| 328 |
"reason": "Pricing and CTA (\"Link in bio\") refer to the product purchase"
|
| 329 |
}
|
| 330 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
"""
|
| 332 |
|
| 333 |
# Default fallback
|
|
|
|
| 328 |
"reason": "Pricing and CTA (\"Link in bio\") refer to the product purchase"
|
| 329 |
}
|
| 330 |
]
|
| 331 |
+
"""
|
| 332 |
+
|
| 333 |
+
# 6. Transcript Verification Prompt
|
| 334 |
+
if "verify and correct the timed words" in prompt_lower or "speech-to-text alignment" in prompt_lower:
|
| 335 |
+
return """
|
| 336 |
+
[
|
| 337 |
+
{ "word": "If", "start_time": 0.2, "end_time": 0.4, "confidence": 0.99 },
|
| 338 |
+
{ "word": "you're", "start_time": 0.4, "end_time": 0.5, "confidence": 0.99 },
|
| 339 |
+
{ "word": "creating", "start_time": 0.5, "end_time": 0.9, "confidence": 0.99 },
|
| 340 |
+
{ "word": "content", "start_time": 0.9, "end_time": 1.3, "confidence": 0.99 },
|
| 341 |
+
{ "word": "for", "start_time": 1.3, "end_time": 1.4, "confidence": 0.99 },
|
| 342 |
+
{ "word": "social", "start_time": 1.4, "end_time": 1.8, "confidence": 0.99 },
|
| 343 |
+
{ "word": "media,", "start_time": 1.8, "end_time": 2.3, "confidence": 0.99 },
|
| 344 |
+
{ "word": "you", "start_time": 2.3, "end_time": 2.4, "confidence": 0.99 },
|
| 345 |
+
{ "word": "need", "start_time": 2.4, "end_time": 2.7, "confidence": 0.99 },
|
| 346 |
+
{ "word": "b-roll", "start_time": 2.7, "end_time": 3.3, "confidence": 0.99 }
|
| 347 |
+
]
|
| 348 |
"""
|
| 349 |
|
| 350 |
# Default fallback
|
src/google_src/stt.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import json
|
|
|
|
| 2 |
from typing import List, Dict, Union
|
| 3 |
from google.cloud import speech_v1 as speech
|
| 4 |
from src.utils import logger
|
|
@@ -12,7 +13,7 @@ class GoogleSTT:
|
|
| 12 |
credentials = get_gcs_credentials("final_data")
|
| 13 |
self.client = speech.SpeechClient(credentials=credentials)
|
| 14 |
|
| 15 |
-
def generate_timed_transcript(self, audio_input: Union[str, bytes]) -> List[Dict]:
|
| 16 |
"""
|
| 17 |
Generate timed transcript using Google Cloud Speech-to-Text.
|
| 18 |
|
|
@@ -106,6 +107,47 @@ class GoogleSTT:
|
|
| 106 |
|
| 107 |
logger.info(f"✅ Generated timed transcript: {len(words)} words")
|
| 108 |
logger.debug(f"Timed Transcript:\n{json.dumps(words, indent=2)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
return words
|
| 110 |
|
| 111 |
except Exception as e:
|
|
|
|
| 1 |
import json
|
| 2 |
+
import json5
|
| 3 |
from typing import List, Dict, Union
|
| 4 |
from google.cloud import speech_v1 as speech
|
| 5 |
from src.utils import logger
|
|
|
|
| 13 |
credentials = get_gcs_credentials("final_data")
|
| 14 |
self.client = speech.SpeechClient(credentials=credentials)
|
| 15 |
|
| 16 |
+
def generate_timed_transcript(self, audio_input: Union[str, bytes], verify_with_text: str = None) -> List[Dict]:
|
| 17 |
"""
|
| 18 |
Generate timed transcript using Google Cloud Speech-to-Text.
|
| 19 |
|
|
|
|
| 107 |
|
| 108 |
logger.info(f"✅ Generated timed transcript: {len(words)} words")
|
| 109 |
logger.debug(f"Timed Transcript:\n{json.dumps(words, indent=2)}")
|
| 110 |
+
|
| 111 |
+
if verify_with_text:
|
| 112 |
+
logger.info("🔍 Verifying transcript with text...")
|
| 113 |
+
|
| 114 |
+
try:
|
| 115 |
+
# Construct prompt for verification
|
| 116 |
+
prompt_path = os.path.join(os.path.dirname(__file__), "../prompt/stt_verification.md")
|
| 117 |
+
if os.path.exists(prompt_path):
|
| 118 |
+
with open(prompt_path, "r") as f:
|
| 119 |
+
prompt_template = f.read()
|
| 120 |
+
|
| 121 |
+
prompt = prompt_template.format(
|
| 122 |
+
verify_with_text=verify_with_text,
|
| 123 |
+
timed_words_json=json.dumps(words)
|
| 124 |
+
)
|
| 125 |
+
else:
|
| 126 |
+
logger.warning(f"⚠️ Prompt file not found at {prompt_path}, skipping verification.")
|
| 127 |
+
return words
|
| 128 |
+
|
| 129 |
+
from . import ai_studio_sdk
|
| 130 |
+
response_text = ai_studio_sdk.generate(prompt)
|
| 131 |
+
|
| 132 |
+
if response_text:
|
| 133 |
+
# Clean up response if it contains markdown code blocks
|
| 134 |
+
clean_response = response_text.replace("```json", "").replace("```", "").strip()
|
| 135 |
+
corrected_words = json5.loads(clean_response)
|
| 136 |
+
|
| 137 |
+
# Basic validation
|
| 138 |
+
if isinstance(corrected_words, list) and len(corrected_words) > 0:
|
| 139 |
+
logger.info(f"✅ Verified transcript: {len(corrected_words)} words")
|
| 140 |
+
words = corrected_words
|
| 141 |
+
else:
|
| 142 |
+
logger.warning("⚠️ Verification returned invalid format, keeping original transcript.")
|
| 143 |
+
else:
|
| 144 |
+
logger.warning("⚠️ Verification failed (no response), keeping original transcript.")
|
| 145 |
+
|
| 146 |
+
except Exception as e:
|
| 147 |
+
logger.error(f"⚠️ Transcript verification failed: {e}")
|
| 148 |
+
# Fallback to original words on failure
|
| 149 |
+
|
| 150 |
+
|
| 151 |
return words
|
| 152 |
|
| 153 |
except Exception as e:
|
src/pipelines/ai_pipeline.py
CHANGED
|
@@ -128,7 +128,7 @@ class AIContentAutomationBase(ContentAutomationBase):
|
|
| 128 |
|
| 129 |
# Generate timed transcript
|
| 130 |
logger.info("\n⏱️ STEP 5a: Generate Timed Transcript")
|
| 131 |
-
timed_words = self.stt.generate_timed_transcript(tts_audio_data["local_path"])
|
| 132 |
visual_assets = get_config_value("visual_assets", {})
|
| 133 |
visual_assets["timed_transcript"] = timed_words
|
| 134 |
set_config_value("visual_assets", visual_assets)
|
|
|
|
| 128 |
|
| 129 |
# Generate timed transcript
|
| 130 |
logger.info("\n⏱️ STEP 5a: Generate Timed Transcript")
|
| 131 |
+
timed_words = self.stt.generate_timed_transcript(tts_audio_data["local_path"], tts_audio_data["text"])
|
| 132 |
visual_assets = get_config_value("visual_assets", {})
|
| 133 |
visual_assets["timed_transcript"] = timed_words
|
| 134 |
set_config_value("visual_assets", visual_assets)
|
src/prompt/stt_verification.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are an expert in speech-to-text alignment.
|
| 2 |
+
I have a ground truth text and a list of timed words generated by an STT engine.
|
| 3 |
+
The STT engine might have made spelling mistakes, missed words, or added extra words.
|
| 4 |
+
|
| 5 |
+
Your task is to correct the 'word' field in the timed words list to match the ground truth text,
|
| 6 |
+
while preserving the 'start_time' and 'end_time' as much as possible.
|
| 7 |
+
If a word is present in the ground truth but missing from the STT output, you can interpolate its timings or merge/split existing timings,
|
| 8 |
+
but the most important thing is that the sequence of 'word' values in your output EXACTLY matches the ground truth text. Remove any extra words that the STT engine added and that do not appear in the ground truth.
|
| 9 |
+
|
| 10 |
+
Ground Truth Text:
|
| 11 |
+
"{verify_with_text}"
|
| 12 |
+
|
| 13 |
+
Timed Words List (JSON):
|
| 14 |
+
{timed_words_json}
|
| 15 |
+
|
| 16 |
+
Return ONLY the corrected JSON list of objects with keys: "word", "start_time", "end_time", "confidence".
|
| 17 |
+
Do not include any markdown formatting or explanation; return just the JSON.
|