# TQTune / data_prep.py
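"""Prepare translation-quality (TQ) annotation data for tuning.

Reads raw annotation exports, enriches each record with timed-text context
from the local cache, computes a 32-point quality score per labeler, and
writes the parsed results back out as JSON.
"""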
import typing as T
import os
import sys
import argparse
import json
import re

import nflx_copilot as ncp
import pandas as pd

sys.path.append("/root/workspace")
from timedtext.adapters.translation.generation.pldl import TimedTextAdapter, ConverterDialogContext
from timedtext.manager import TimedTextManager
from timedtext.handlers import OriginalLanguagePivotLanguageHandler, EnglishTemplateSubtitleHandler
from timedprompts.evaluation.pldl_prompt_one.prompt import (
ReferenceFreeFeedbackTransform,
ContextFreeFeedbackTransform,
ReferenceFreeDirectTransform,
ReferenceBasedFeedbackTransform,
ReferenceFreeExampleTransform,
)
from tqdm import tqdm
from timedtune.convert.tq_for_pldl.pldl_train_one import PldlTrainOneReferenceFreeTransform
from timedtext.adapters.translation.evaluation import compute_score_delta
def compute_32_point_score(response, generation):
    """Combine per-axis scores and issue deltas into a single 32-point score."""
    parsed, score = {}, -1
    try:
        score = (
            int(response["Accuracy Score"])
            + int(response["Readability Score"])
            + compute_score_delta(response, "Accuracy Issues", generation)
            + compute_score_delta(response, "Readability Issues", generation)
        )
        score = score * 4  # scale the combined total to the 32-point range
    except Exception:  # missing or non-numeric fields fall back to the sentinel
        score = -1
    return parsed, score
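# Worked example: Accuracy Score 4 + Readability Score 4 with zero issue deltas
# gives (4 + 4 + 0 + 0) * 4 = 32 (assuming per-axis scores cap at 4; the caps
# are not enforced here).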
# TimedTextAdapter subclasses that read timed text from the local cache
class TimedTextAdapterFromCache_PLDL(TimedTextAdapter):
def __init__(
self,
data_dir: str,
cache_size: int = 0,
ol_dialog_list_version: str = "",
pl_dialog_list_version: str = "",
ol_dialog_list_pl_dialog_list_version: str = "",
num_prev_events: int = 16,
num_next_events: int = 16,
) -> None:
super().__init__(num_prev_events, num_next_events)
self.timed_text_manager = TimedTextManager(
data_dir,
cache_size=cache_size,
ol_dialog_list_version=ol_dialog_list_version,
pl_dialog_list_version=pl_dialog_list_version,
ol_dialog_list_pl_dialog_list_version=ol_dialog_list_pl_dialog_list_version,
)
def _get_timed_text(
self, movie_id: int, start_frame: int, end_frame: int, src_lang: str, tgt_lang: str
) -> T.Dict[str, T.Union[T.Dict, T.List[T.Dict]]]:
results = self.timed_text_manager.match_and_get_timed_text(
handler_class=OriginalLanguagePivotLanguageHandler,
movie_id=movie_id,
start_frame=start_frame,
end_frame=end_frame,
src_lang=src_lang,
tgt_lang=tgt_lang,
mid_lang="",
**self.timed_text_kwargs,
)
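        # Concatenate all matched "current" events into a single src/tgt block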
curr_srcs = [result["curr"]["src"]["txt"] for result in results]
curr_tgts = [result["curr"]["tgt"]["txt"] for result in results]
return {
"curr": {"src": {"txt": "\n\n".join(curr_srcs)}, "tgt": {"txt": "\n\n".join(curr_tgts)}},
"prev": results[0]["prev"],
"next": results[-1]["next"],
}
class TimedTextAdapterFromCache_SUBS(TimedTextAdapter):
def __init__(
self,
data_dir: str,
cache_size: int = 0,
num_prev_events: int = 16,
num_next_events: int = 16,
) -> None:
super().__init__(num_prev_events, num_next_events)
self.timed_text_manager = TimedTextManager(
data_dir,
cache_size=cache_size,
)
def _get_timed_text(
self, movie_id: int, start_frame: int, end_frame: int, src_lang: str, tgt_lang: str
) -> T.Dict[str, T.Union[T.Dict, T.List[T.Dict]]]:
results = self.timed_text_manager.match_and_get_timed_text(
handler_class=EnglishTemplateSubtitleHandler,
movie_id=movie_id,
start_frame=start_frame,
end_frame=end_frame,
src_lang=src_lang,
tgt_lang=tgt_lang,
mid_lang="",
**self.timed_text_kwargs,
)
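        # Concatenate all matched "current" events into a single src/tgt block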
curr_srcs = [result["curr"]["src"]["txt"] for result in results]
curr_tgts = [result["curr"]["tgt"]["txt"] for result in results]
return {
"curr": {"src": {"txt": "\n\n".join(curr_srcs)}, "tgt": {"txt": "\n\n".join(curr_tgts)}},
"prev": results[0]["prev"],
"next": results[-1]["next"],
}
# Function to fetch contextual information using TimedTextAdapter
def fetch_contextual_information(timed_text_adapter, row):
"""
Fetches the required context information for each sample using timed_text_adapter.
Args:
timed_text_adapter (TimedTextAdapterFromCache): Adapter to fetch data from.
row (dict): Row containing the necessary information to fetch the context.
Returns:
dict: Contextual information containing src_text, tgt_text, prev_context, next_context, src_prev, src_next, tgt_prev, tgt_next.
"""
# Fetching the actual translation context
src_text, tgt_text, prev_context, next_context = timed_text_adapter.get_timed_text(
movie_id=row["movie_id"],
start_frame=row["start_frame"],
end_frame=row["end_frame"],
src_lang=row["src_lang"],
tgt_lang=row["tgt_lang"],
)
timed_text_converter = ConverterDialogContext(timed_text_adapter)
# Converting context to the format expected by the prompt
src_prev, src_next, tgt_prev, tgt_next, _ = timed_text_converter.__context__(
row["src_lang"], row["tgt_lang"], prev_context, next_context, None
)
return {
"tt_src_text": src_text,
"tt_tgt_text": tgt_text,
"tt_src_prev": src_prev,
"tt_src_next": src_next,
"tt_tgt_prev": tgt_prev,
"tt_tgt_next": tgt_next,
}
def transform_json(input_json):
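    # The input looks like a Labelbox-style export (projects -> labels ->
    # annotations with objects/classifications); the field names below assume
    # that layout.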
# Get the first project key
project_key = list(input_json['projects'].keys())[0]
project = input_json['projects'][project_key]
final_output = {"labelers": []}
# Process each label
    for label in project['labels']:
# Initialize output structure
output = {
"annotation": {
"Accuracy Issues": [],
"Readability Issues": [],
"Accuracy Score": "",
"Readability Score": "",
"Confidence Level": "",
"Main Vs Alternate": "",
"Score": "-1" # initalized -1, will be updated in next steps
},
}
# Process annotations/objects (issues)
if 'objects' in label['annotations']:
for obj in label['annotations']['objects']:
issue = {
"Error Location": obj['conversational_location']['message_id'],
"Error Span": [
obj['conversational_location']['location']['start'],
obj['conversational_location']['location']['end']
],
"Error Explanation": "",
"Error Quality Category": obj['name'],
"Error Quality Tags": [],
"Error Severity": ""
}
# Process classifications within object
for classification in obj['classifications']:
if classification['name'] == 'Explanation':
issue["Error Explanation"] = classification['text_answer']['content']
elif classification['name'] == 'Quality Tag':
issue["Error Quality Tags"] = [ans['name'].lower() for ans in classification['checklist_answers']]
elif classification['name'] == 'Quality SubCategory':
severity = classification['radio_answer']['name']
if 'Major' in severity:
issue["Error Severity"] = "Major"
else:
issue["Error Severity"] = "Minor"
# Add to appropriate issues list
if obj['name'] == 'Style':
output['annotation']['Readability Issues'].append(issue)
else:
output['annotation']['Accuracy Issues'].append(issue)
# Process classifications
for classification in label['annotations']['classifications']:
if classification['name'] == 'Accuracy Score':
output['annotation']['Accuracy Score'] = classification['radio_answer']['name'].split(' - ')[0]
elif classification['name'] == 'Readability Score':
output['annotation']['Readability Score'] = classification['radio_answer']['name'].split(' - ')[0]
elif classification['name'] == 'Confidence Level':
output['annotation']['Confidence Level'] = classification['radio_answer']['value']
elif classification['name'] == 'Main vs Alternate':
output['annotation']['Main Vs Alternate'] = classification['radio_answer']['name']
final_output["labelers"].append(output)
return final_output
# Function to load the relevant meta json for a given key
def load_meta_json(priority_key, data_row_key, meta_path):
"""
Loads and validates metadata json from the specified path based on the priority key and data row key.
Args:
priority_key (str): Priority key from the label metadata.
data_row_key (str): Data row key to find the relevant file.
meta_path (str): Path to the metadata folder.
Returns:
dict: Loaded metadata.
"""
with open(os.path.join(meta_path, f'{priority_key}.json')) as fread:
meta_dict = json.load(fread)
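    # data_row_key has seven dot-separated fields; field 2 is the movie id and
    # field 3 is "<start_frame>_<end_frame>"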
_, movie_id, start_end_frame, _, _, _, _ = data_row_key.split('.')
start_frame, end_frame = start_end_frame.split('_')
    if int(meta_dict['movie_id']) != int(movie_id):
        raise ValueError(
            f"Movie ids didn't match: {meta_dict['movie_id']} vs {movie_id} "
            f"({os.path.join(meta_path, f'{priority_key}.json')} / {data_row_key})"
        )
assert int(meta_dict['start_frame']) == int(start_frame)
assert int(meta_dict['end_frame']) == int(end_frame)
return meta_dict
# Main function that processes the data
def process_json(timed_text_adapter, example_row, meta_path, conv_path):
"""
Takes the full input json, converts it to the required format, and adds context using metadata.
Args:
timed_text_adapter (TimedTextAdapterFromCache): Adapter to fetch context.
example_row (dict): The full input JSON (like the example_row you provided).
meta_path (str): Path to the metadata folder to fetch meta json.
Returns:
dict: The enriched annotation format with context and annotation data.
"""
# Step 1: Convert the full input JSON to the required annotation format
annotation_result = transform_json(example_row)
# Extracting the necessary data_row_key and priority_key
data_row_key = example_row['data_row']['global_key']
priority_key = example_row['projects'][list(example_row["projects"].keys())[0]]['project_details']['priority']
annotation_result["Data_Row_Key"] = data_row_key
key = ".".join(data_row_key.split(".")[:3])
    with open(os.path.join(conv_path, f"{key}.json")) as file:
data = json.load(file)
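    # Conversation layout: messages[0] = main target, messages[1] = source,
    # messages[2] = alternate target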
annotation_result["main_tgt_text"] = data["messages"][0]["content"]
annotation_result["src_text"] = data["messages"][1]["content"]
annotation_result["alt_tgt_text"] = data["messages"][2]["content"]
# Load the metadata using the keys from the json
meta_dict = load_meta_json(priority_key, data_row_key, meta_path)
# Step 2: Add the metadata fields (e.g., title_id, start_frame, end_frame, src_lang, tgt_lang)
annotation_result.update({
"title_id": meta_dict['movie_id'],
"start_frame": meta_dict['start_frame'],
"end_frame": meta_dict['end_frame'],
"src_lang": meta_dict['src_lang'],
"tgt_lang": meta_dict['tgt_lang'],
})
# Step 3: Fetch contextual information using the given timed_text_adapter
context_info = fetch_contextual_information(timed_text_adapter, meta_dict)
annotation_result.update(context_info)
    # Replace each [start, end] error span with the actual text it refers to
    for labeler in annotation_result["labelers"]:
        for issue_list in ("Accuracy Issues", "Readability Issues"):
            for issue in labeler["annotation"][issue_list]:
                start, end = issue["Error Span"]
                # Slice from the source or the main target text, depending on
                # where the error was marked
                if issue["Error Location"] == "src":
                    issue["Error Span"] = annotation_result["src_text"][start:end]
                else:  # tgt
                    issue["Error Span"] = annotation_result["main_tgt_text"][start:end]
return annotation_result
# Entry point: parse every raw annotation export and write the enriched jsons
def main():
base_path = "MT_TQ/Caches/May2025/tquality.annotated.data/"
json_files = [base_path + "raw/" + f for f in os.listdir(base_path + "raw/") if f.endswith('.json')]
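    # Raw exports live under raw/; enriched results are written to parsed/ below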
for json_file in tqdm(json_files):
if "calibration" in json_file:
print("Warning: Skipping Calibration Data, Remove this if you want to use Calibration data")
continue
if "PLDL" in json_file:
folder = "pldl"
timed_text_adapter = TimedTextAdapterFromCache_PLDL(
data_dir="/fsx_l10n/l10n_dse_timedtext/cache", num_prev_events=32, num_next_events=32
)
elif "SUBS" in json_file:
folder = "subs"
timed_text_adapter = TimedTextAdapterFromCache_SUBS(
data_dir="/fsx_l10n/l10n_dse_timedtext/cache", num_prev_events=32, num_next_events=32
)
        else:
            raise ValueError(f"Unexpected json file (neither PLDL nor SUBS): {json_file}")
        fname = os.path.basename(json_file)
        langs_type = fname.split("-")[1].replace("_", ".")
        phase = fname.split("-")[3]
        phase_digits = re.findall(r"\d+", phase)
        phase_number = int("".join(phase_digits)) if phase_digits else None
        phase_date = fname.split("-")[4].replace(".json", "")
zzmetapath = f"/root/notebooks/MT_TQ/Caches/labelspace/tquality.zzmeta.data/{folder}/{langs_type}/phase {phase_number} - {phase_date}"
meta_path = zzmetapath + "/meta"
conv_path = zzmetapath + "/conv"
with open(json_file) as file:
data = json.load(file)
output_data = []
for data_point in tqdm(data):
annotation_result = process_json(timed_text_adapter, data_point, meta_path, conv_path)
for labeler in annotation_result["labelers"]:
_, score = compute_32_point_score(labeler["annotation"], annotation_result["main_tgt_text"])
labeler["annotation"]["Score"] = score
output_data.append(annotation_result)
        with open(base_path + "parsed/" + os.path.basename(json_file), "w") as fout:
            json.dump({"data": output_data}, fout, indent=4)
if __name__ == "__main__":
main()