ylmmhf's picture
Debug the last interval issue
0dbe13b verified
raw
history blame
12.1 kB
import csv
import io
import math
import os
import re
import tempfile
import unicodedata
import zipfile
from pathlib import Path

import gradio as gr
# ==== Core Functions ====
def create_interval_data_dict(xmin, xmax, sentence):
    """Return one interval record ``{'xmin': float, 'xmax': float, 'text': sentence}``.

    Both bounds are passed through ``float()`` (so numeric strings are
    accepted); *sentence* is stored unchanged.
    """
    start = float(xmin)
    end = float(xmax)
    return dict(xmin=start, xmax=end, text=sentence)
def write_textgrid_file(intervals, output_file_path, total_xmax, tier_name):
    """Write a single-tier Praat TextGrid in the long text ("ooTextFile") format.

    Args:
        intervals: ordered sequence of dicts with ``'xmin'``, ``'xmax'`` and
            ``'text'`` keys (as built by ``create_interval_data_dict``).
        output_file_path: destination path; an existing file is overwritten.
        total_xmax: end time written for both the TextGrid and its tier
            (the start time is always written as 0).
        tier_name: name of the single IntervalTier.

    The file is written as UTF-8 so non-ASCII transcriptions do not depend on
    the platform's default locale encoding.
    """

    def quoted(value):
        # Praat string literals escape an embedded double quote by doubling it;
        # an unescaped quote would end the string early and corrupt the file.
        return '"' + str(value).replace('"', '""') + '"'

    total = float(total_xmax)
    lines = [
        'File type = "ooTextFile"',
        'Object class = "TextGrid"',
        '',
        'xmin = 0',
        f'xmax = {total}',
        'tiers? <exists>',
        'size = 1',
        'item []:',
        '    item [1]:',
        '        class = "IntervalTier"',
        f'        name = {quoted(tier_name)}',
        '        xmin = 0',
        f'        xmax = {total}',
        f'        intervals: size = {len(intervals)}',
    ]
    for idx, interval in enumerate(intervals, start=1):
        lines.extend([
            f'        intervals [{idx}]:',
            f'            xmin = {interval["xmin"]}',
            f'            xmax = {interval["xmax"]}',
            f'            text = {quoted(interval["text"])}',
        ])
    with open(output_file_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')
def validate_csv_format(header):
    """Tell whether *header* matches one of the two accepted CSV layouts.

    Accepted exactly (order, spelling and case all matter):
      * ``file_name, xmin, xmax, text, is_unit_start_pred``
      * the same columns preceded by an unnamed (``''``) index column.
    """
    columns = ['file_name', 'xmin', 'xmax', 'text', 'is_unit_start_pred']
    with_index = [''] + columns
    return header == columns or header == with_index
def _is_allowed_text_char(ch):
    """Return True if *ch* is a letter, mark, number, punctuation or space separator."""
    category = unicodedata.category(ch)
    # Marks (M*) are accepted so scripts built on combining characters
    # (e.g. Devanagari vowel signs, Arabic diacritics) pass validation.
    return category[0] in 'LMNP' or category == 'Zs'


def validate_row(row, header):
    """Check one CSV data row against the layout described by *header*.

    A row is accepted when it has enough columns for the layout, ``xmin`` and
    ``xmax`` parse as finite floats with ``xmin < xmax``, and the stripped
    ``text`` consists only of Unicode letters, marks, numbers, punctuation and
    space separators (so tabs, newlines and other control characters are
    rejected). ``is_unit_start_pred`` is not validated here; callers treat any
    value other than "true" (case-insensitive) as false.

    Args:
        row: list of cell strings from ``csv.reader``.
        header: the CSV header row; must be one of the layouts accepted by
            ``validate_csv_format``.

    Returns:
        ``(True, "")`` if the row is usable, otherwise ``(False, reason)``.
    """
    columns = ['file_name', 'xmin', 'xmax', 'text', 'is_unit_start_pred']
    if header == columns:
        offset = 0
    elif header == [''] + columns:
        offset = 1
    else:
        offset = None
    # The indexed layout needs one extra column; checking the real width keeps a
    # short row from raising IndexError further down.
    if len(row) < len(columns) + (offset or 0):
        return False, "Row does not have enough columns."
    if offset is None:
        return False, "Invalid header format."
    try:
        xmin = float(row[offset + 1])
        xmax = float(row[offset + 2])
    except ValueError:
        return False, "Data format error (possibly number conversion failed)."
    # float() accepts "nan"/"inf", which would produce meaningless intervals.
    if not (math.isfinite(xmin) and math.isfinite(xmax)):
        return False, "xmin and xmax must be finite numbers."
    if xmin >= xmax:
        return False, "xmin must be less than xmax."
    # Python's re module does not support \p{...} property classes, so the
    # Unicode general categories are checked via unicodedata instead.
    text = row[offset + 3].strip()
    if not all(_is_allowed_text_char(ch) for ch in text):
        return False, "Text contains invalid characters."
    return True, ""
# ==== Gradio Interface Function ====
def _read_uploaded_text(file):
    """Return the text content of an uploaded CSV.

    ``gr.File`` can hand over different objects depending on Gradio version and
    configuration, so this accepts a filesystem path (``str``/``os.PathLike``),
    raw ``bytes``, an object whose ``.name`` is an existing file path (e.g. a
    temp-file wrapper), or any other object with ``.read()``.

    Raises:
        OSError, UnicodeDecodeError: a path could not be read as UTF-8.
        ValueError: *file* is of an unsupported type.
    """
    if isinstance(file, (bytes, bytearray)):
        return bytes(file).decode('utf-8', errors='replace')
    if isinstance(file, (str, os.PathLike)):
        path = file
    else:
        name = getattr(file, 'name', None)
        path = name if isinstance(name, (str, os.PathLike)) and os.path.isfile(name) else None
    if path is not None:
        # newline='' lets the csv module handle line endings and quoted newlines.
        with open(path, 'r', encoding='utf-8', newline='') as f:
            return f.read()
    if hasattr(file, 'read'):
        data = file.read()
        return data.decode('utf-8', errors='replace') if isinstance(data, bytes) else str(data)
    raise ValueError(f"unsupported upload type {type(file).__name__}")


def _safe_unique_stem(name, used):
    """Map a CSV ``file_name`` to a unique file-name stem safe to use in the output dir.

    Path separators and characters Windows forbids in file names are replaced
    with ``_`` so a crafted ``file_name`` (e.g. ``../x`` or ``/tmp/x``) cannot
    write outside the temporary output directory or point into a missing
    sub-directory. A numeric suffix keeps stems unique (case-insensitively, for
    zip extraction on case-insensitive file systems). *used* is updated in place.
    """
    stem = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '_', name)
    candidate, counter = stem, 1
    while candidate.lower() in used:
        counter += 1
        candidate = f"{stem}_{counter}"
    used.add(candidate.lower())
    return candidate


def _build_file_intervals(word_rows):
    """Turn one file's word rows into a contiguous list of interval dicts.

    Args:
        word_rows: non-empty list of ``(xmin, xmax, text, is_unit_start)`` tuples
            in CSV order (assumed chronological — rows are not re-sorted).

    A unit begins at the first row and at every row flagged as a unit start.
    Its text is the space-joined words, and it spans from its first word's xmin
    to its last word's xmax. An empty interval fills any gap between the previous
    interval's end (or time 0) and a unit's start. After the last unit, an empty
    interval of 0.001 s is appended.
    """
    units = []  # each entry: [start, end, words]
    for xmin, xmax, text, is_unit_start in word_rows:
        if is_unit_start or not units:
            units.append([xmin, xmax, [text]])
        else:
            units[-1][1] = xmax
            units[-1][2].append(text)

    intervals = []
    for start, end, words in units:
        prev_end = intervals[-1]['xmax'] if intervals else 0.0
        # Only fill real gaps; an overlapping unit must not yield an inverted interval.
        if start > prev_end:
            intervals.append(create_interval_data_dict(prev_end, start, ''))
        intervals.append(create_interval_data_dict(start, end, ' '.join(words)))
    if intervals:
        last_xmax = intervals[-1]['xmax']
        intervals.append(create_interval_data_dict(last_xmax, last_xmax + 0.001, ''))
    return intervals


def csv_to_textgrid(file, tier_name=""):
    """Convert an uploaded word-level CSV into Praat TextGrid files packed in a zip.

    Rows are grouped by ``file_name`` (first-seen order). Each group becomes one
    single-tier TextGrid built by ``_build_file_intervals``. Rows rejected by
    ``validate_row`` and rows with an empty file name are skipped and logged to
    stdout.

    Args:
        file: the value of the ``gr.File`` input (see ``_read_uploaded_text``).
        tier_name: name of the IntervalTier written to every TextGrid.

    Returns:
        ``(zip_path, status_message)`` on success, ``(None, error_message)``
        otherwise — one value per Gradio output component.
    """
    if file is None:
        return None, "No file uploaded. Please upload a CSV file."
    if tier_name is None:
        tier_name = ""
    try:
        try:
            file_content = _read_uploaded_text(file)
        except (OSError, UnicodeDecodeError, ValueError) as e:
            print(f"Error reading file: {e}")
            return None, f"Error reading file: {e}"
        # Spreadsheet exports often start with a UTF-8 byte-order mark, which
        # would otherwise end up in the first header cell and fail validation.
        if file_content.startswith('\ufeff'):
            file_content = file_content[1:]

        reader = csv.reader(io.StringIO(file_content, newline=''))
        try:
            header = next(reader, None)
            if header is None:
                return None, "The uploaded CSV file is empty."
            if not validate_csv_format(header):
                return None, "Invalid CSV format. Expected headers: file_name, xmin, xmax, text, is_unit_start_pred"
            print(f"Header: {header}")
            # The indexed layout has an extra unnamed first column.
            offset = 1 if header[0] == '' else 0
            # Grouping (rather than flushing on each file-name change) keeps rows
            # of one file together even if they are not contiguous in the CSV.
            rows_by_file = {}
            for row in reader:
                valid, message = validate_row(row, header)
                if not valid:
                    print(f"Skipping invalid row: {row} - {message}")
                    continue
                filename = row[offset].strip()
                if not filename:
                    print(f"Skipping row with no filename: {row}")
                    continue
                rows_by_file.setdefault(filename, []).append((
                    float(row[offset + 1]),
                    float(row[offset + 2]),
                    row[offset + 3].strip(),
                    row[offset + 4].strip().lower() == "true",
                ))
        except csv.Error as e:
            print(f"Error processing CSV: {e}")
            return None, f"Error processing CSV: {e}"

        if not rows_by_file:
            return None, "No files were processed. Please check your CSV format."

        # The directory is intentionally not removed: Gradio serves the zip from it.
        temp_dir = tempfile.mkdtemp()
        output_directory = os.path.join(temp_dir, "textgrids")
        os.makedirs(output_directory, exist_ok=True)
        used_stems = set()
        written_paths = []
        for filename, word_rows in rows_by_file.items():
            intervals = _build_file_intervals(word_rows)
            stem = _safe_unique_stem(filename, used_stems)
            textgrid_path = os.path.join(output_directory, f"{stem}.TextGrid")
            write_textgrid_file(intervals, textgrid_path, intervals[-1]['xmax'], tier_name)
            written_paths.append(textgrid_path)
            print(f"Wrote file: {stem}.TextGrid with {len(intervals)} intervals")

        zip_path = os.path.join(temp_dir, "textgrids.zip")
        with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
            for tg_path in written_paths:
                zipf.write(tg_path, os.path.basename(tg_path))
                print(f"Added to zip: {os.path.basename(tg_path)}")
        if os.path.getsize(zip_path) > 0:
            processed_files = list(rows_by_file)
            return zip_path, f"Successfully processed {len(processed_files)} files: {', '.join(processed_files)}"
        return None, "Zip file creation failed."
    except Exception as e:
        # Top-level boundary of the Gradio callback: report instead of crashing the UI.
        print(f"Error in processing: {str(e)}")
        return None, f"Error: {str(e)}"
# ==== Gradio Interface Setup ====
# Markdown appended to the app description; documents the accepted CSV layouts.
csv_format_instruction = """
**Expected CSV format:**\n
The first row is the header. Each subsequent row should contain:\n
With index: `, file_name, xmin, xmax, text, is_unit_start_pred`\n
Without index: `file_name, xmin, xmax, text, is_unit_start_pred`\n\n
- `file_name`: Identifier for the audio file (used to group intervals).\n
- `xmin`: Start time of the segment (in seconds).\n
- `xmax`: End time of the segment (in seconds).\n
- `text`: The actual spoken word or phrase (supports multiple languages).\n
- `is_unit_start_pred`: Marks the beginning of a new unit (TRUE/FALSE).\n
**Please enter your preferred tier name in the space below.**\n
Example (with index, works the same without index):\n
| | file_name | xmin | xmax | text | is_unit_start_pred |
|-|-----------|--------|--------|-------|---------------------|
|0| example1 | 20.42 | 20.74 | Hello | TRUE |
|1| example1 | 20.74 | 20.81 | World | TRUE |
|2| example1 | 20.81 | 20.92 | ! | FALSE |
"""

# Inputs map to csv_to_textgrid(file, tier_name); outputs receive its
# (zip_path, status_message) return pair in order.
iface = gr.Interface(
    fn=csv_to_textgrid,
    inputs=[
        gr.File(label="📁 Upload CSV File", file_types=[".csv"]),
        gr.Textbox(label="📝 Enter Tier Name", placeholder="Enter the tier name"),
    ],
    outputs=[
        gr.File(label="📦 Download TextGrid ZIP"),
        gr.Textbox(label="✅ Status"),
    ],
    title="CSV2Praat Auto Tool",
    description="Upload a properly formatted CSV file and receive one or more Praat TextGrid files in return. Files will be packed in a .zip for download.\n\n" + csv_format_instruction,
)

if __name__ == "__main__":
    iface.launch()