# zyshan-ds's picture
# Modified the description
# eefbbb4 verified
import csv
import os
import tempfile
import zipfile
from pathlib import Path
import gradio as gr
import io
import re
# ==== Core Functions ====
def create_interval_data_dict(xmin, xmax, sentence):
    """Build the record used to represent one TextGrid interval.

    Args:
        xmin: Interval start; anything ``float()`` accepts.
        xmax: Interval end; anything ``float()`` accepts.
        sentence: Label stored under the ``'text'`` key, unchanged.

    Returns:
        A dict with keys ``'xmin'``, ``'xmax'`` (both floats) and ``'text'``.
    """
    start = float(xmin)
    end = float(xmax)
    return dict(xmin=start, xmax=end, text=sentence)
def write_textgrid_file(intervals, output_file_path, total_xmax, tier_name):
    """Write a single-tier Praat TextGrid (long text format) to disk.

    Args:
        intervals: Sequence of dicts with ``'xmin'``, ``'xmax'`` and
            ``'text'`` keys, written in the order given.
        output_file_path: Destination path; an existing file is overwritten.
        total_xmax: End time written for both the TextGrid and its tier;
            converted with ``float()``.
        tier_name: Name written for the single interval tier.

    The file is always written as UTF-8 so that non-ASCII labels do not
    depend on the platform's default encoding.
    """
    def quoted(value):
        # Praat text files escape a double quote inside a string by doubling
        # it; an unescaped quote would end the string early.
        return '"' + str(value).replace('"', '""') + '"'

    end_time = str(float(total_xmax))
    lines = [
        'File type = "ooTextFile"\n',
        'Object class = "TextGrid"\n\n',
        'xmin = 0\n',
        f'xmax = {end_time}\n',
        'tiers? <exists>\n',
        'size = 1\n',
        'item []:\n',
        ' item [1]:\n',
        ' class = "IntervalTier"\n',
        f' name = {quoted(tier_name)}\n',
        ' xmin = 0\n',
        f' xmax = {end_time}\n',
        f' intervals: size = {len(intervals)}\n',
    ]
    for number, interval in enumerate(intervals, start=1):
        lines.append(f' intervals [{number}]:\n')
        lines.append(f' xmin = {interval["xmin"]}\n')
        lines.append(f' xmax = {interval["xmax"]}\n')
        lines.append(f' text = {quoted(interval["text"])}\n')

    with open(output_file_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)
def validate_csv_format(header):
    """Tell whether *header* is exactly the expected CSV header row.

    The expected row is an unnamed index column followed by ``file_name``,
    ``xmin``, ``xmax``, ``text`` and ``is_unit_start_pred``, in that order.
    The comparison is exact: no stripping and no case folding.
    """
    required_columns = list(
        ('', 'file_name', 'xmin', 'xmax', 'text', 'is_unit_start_pred')
    )
    header_matches = header == required_columns
    return header_matches
def validate_row(row):
    """Check one CSV data row against the expected column contents.

    Args:
        row: Sequence of strings laid out as
            ``[index, file_name, xmin, xmax, text, is_unit_start_pred]``.

    Returns:
        ``(True, "")`` when the row passes every check, otherwise
        ``(False, message)`` describing the first check that failed.
    """
    if len(row) < 6:
        return False, "Row does not have enough columns."

    try:
        xmin = float(row[2])
        xmax = float(row[3])
    except ValueError as e:
        return False, f"Value error: {e}"

    # Check time consistency
    if xmin >= xmax:
        return False, "xmin must be less than xmax."

    # Check text content: letters, digits, spaces and , . ! ? only
    text = row[4].strip()
    if not re.match("^[a-zA-Z0-9 ,.!?]*$", text):
        return False, "Text contains invalid characters."

    # The flag must spell a boolean. This test used to be computed and then
    # discarded, so any value was accepted.
    if row[5].strip().lower() not in ("true", "false"):
        return False, "is_unit_start_pred must be TRUE or FALSE."

    return True, ""
# ==== Gradio Interface Function ====
def _read_uploaded_text(file):
    """Return the uploaded CSV as text, whatever form the upload arrives in."""
    if isinstance(file, bytes):
        return file.decode('utf-8')
    if isinstance(file, str):
        # A string that carries a ``name`` attribute, or that names an
        # existing file, is read from disk; any other string is taken to be
        # the CSV text itself.
        if hasattr(file, 'name') or os.path.isfile(file):
            with open(file, 'r', encoding='utf-8', newline='') as f:
                return f.read()
        return file
    if hasattr(file, 'read'):
        data = file.read()
        return data.decode('utf-8') if isinstance(data, bytes) else str(data)
    return str(file)


def _safe_output_stem(name):
    """Make *name* usable as a file name by replacing path separators."""
    # file_name comes from the uploaded CSV; without this, a value such as
    # "../x" would be written outside the output directory.
    return name.replace('/', '_').replace('\\', '_')


def _append_unit(intervals, words, unit_xmin, unit_xmax):
    """Append the pending unit, preceded by an empty filler if there is a gap."""
    if not words:
        return
    prev_xmax = intervals[-1]['xmax'] if intervals else 0
    # Only fill forward gaps: if rows overlap, a filler would have xmin > xmax.
    if prev_xmax < unit_xmin:
        intervals.append(create_interval_data_dict(prev_xmax, unit_xmin, ''))
    intervals.append(create_interval_data_dict(unit_xmin, unit_xmax, ' '.join(words)))


def _finish_textgrid(intervals, words, unit_xmin, unit_xmax, filename, output_directory, tier_name):
    """Flush the pending unit and write *filename*'s TextGrid; return True if written."""
    _append_unit(intervals, words, unit_xmin, unit_xmax)
    if not intervals:
        return False
    # Close the tier with a short empty interval after the last unit.
    last_xmax = intervals[-1]['xmax']
    intervals.append(create_interval_data_dict(last_xmax, last_xmax + 0.001, ''))
    textgrid_path = os.path.join(output_directory, f"{_safe_output_stem(filename)}.TextGrid")
    write_textgrid_file(intervals, textgrid_path, intervals[-1]['xmax'], tier_name)
    print(f"Wrote file: {filename}.TextGrid with {len(intervals)} intervals")
    return True


def _write_textgrids(reader, output_directory, tier_name):
    """Turn the data rows from *reader* into TextGrid files; return the names written."""
    processed_files = []
    intervals = []
    words = []
    iu_xmin = 0
    iu_xmax = 0
    prev_filename = None

    for row in reader:
        if len(row) < 6:
            print(f"Skipping invalid row: {row}")
            continue

        # Report the validator's verdict, which was previously discarded.
        # The row is still processed with the lenient parsing below.
        valid, message = validate_row(row)
        if not valid:
            print(f"Warning: row failed validation ({message}): {row}")

        filename = row[1].strip()
        if not filename:
            print(f"Skipping row with no filename: {row}")
            continue
        try:
            xmin = float(row[2]) if row[2].strip() else 0
            xmax = float(row[3]) if row[3].strip() else 0
        except ValueError as e:
            print(f"Error processing row {row}: {e}")
            continue
        text = row[4].strip()
        is_unit_start_pred = row[5].strip().lower() == "true"
        print(f"Processing: {filename}, {xmin}, {xmax}, {text}, {is_unit_start_pred}")

        # Handle file transition: write out the file that just ended.
        if prev_filename is not None and prev_filename != filename:
            if _finish_textgrid(intervals, words, iu_xmin, iu_xmax,
                                prev_filename, output_directory, tier_name):
                processed_files.append(prev_filename)
            intervals = []
            words = []
            iu_xmin = 0
            iu_xmax = 0
        prev_filename = filename

        if is_unit_start_pred:
            # A new unit starts: close the one collected so far.
            _append_unit(intervals, words, iu_xmin, iu_xmax)
            words = []
            iu_xmin = xmin
        words.append(text)
        iu_xmax = xmax

    # Process the last file
    if prev_filename is not None:
        if _finish_textgrid(intervals, words, iu_xmin, iu_xmax,
                            prev_filename, output_directory, tier_name):
            processed_files.append(prev_filename)
    return processed_files


def _convert_upload(file, tier_name, temp_dir):
    """Run the whole conversion inside *temp_dir*; return ``(zip_path, status)``."""
    csv_path = os.path.join(temp_dir, "input.csv")
    content = _read_uploaded_text(file)
    # Drop a UTF-8 byte-order mark, which would otherwise become part of the
    # first header cell and fail the header check.
    if content.startswith('\ufeff'):
        content = content[1:]
    # newline='' on both write and read, as the csv module expects.
    with open(csv_path, 'w', encoding='utf-8', newline='') as f:
        f.write(content)
    print(f"CSV file written to {csv_path}")

    output_directory = os.path.join(temp_dir, "textgrids")
    os.makedirs(output_directory, exist_ok=True)

    # Process the CSV file
    try:
        with open(csv_path, 'r', encoding='utf-8', newline='') as csvfile:
            reader = csv.reader(csvfile)
            header = next(reader, None)  # None when the file is empty
            if header is None or not validate_csv_format(header):
                return None, "Invalid CSV format. Expected headers: , file_name, xmin, xmax, text, is_unit_start_pred"
            print(f"Header: {header}")
            processed_files = _write_textgrids(reader, output_directory, tier_name)
    except Exception as e:
        print(f"Error processing CSV: {e}")
        return None, f"Error processing CSV: {e}"

    if not processed_files:
        return None, "No files were processed. Please check your CSV format."

    # Create zip file
    zip_path = os.path.join(temp_dir, "textgrids.zip")
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for tg_file in Path(output_directory).rglob("*.TextGrid"):
            zipf.write(tg_file, tg_file.name)
            print(f"Added to zip: {tg_file.name}")
    if os.path.exists(zip_path) and os.path.getsize(zip_path) > 0:
        return zip_path, f"Successfully processed {len(processed_files)} files: {', '.join(processed_files)}"
    return None, "Zip file creation failed."


def csv_to_textgrid(file, tier_name="generated_tier"):
    """Convert an uploaded CSV of timed words into a ZIP of Praat TextGrids.

    Consecutive rows sharing a ``file_name`` form one TextGrid. A row whose
    ``is_unit_start_pred`` is "true" starts a new interval, and later rows
    extend it until the next start row. Gaps before an interval are filled
    with empty intervals, and each tier ends with an empty 0.001-long
    interval. A TextGrid is written whenever ``file_name`` changes, so rows
    for one file should be contiguous: a later run with the same name
    overwrites the earlier file.

    Args:
        file: The upload. Accepted forms are a path string, a string of CSV
            text, bytes, or an object with ``read()``.
        tier_name: Name of the interval tier. A blank or ``None`` value
            falls back to ``"generated_tier"``.

    Returns:
        ``(zip_path, status)``. ``zip_path`` is the path of the ZIP on
        success and ``None`` on failure; ``status`` is a message for the
        user. Errors are reported through ``status`` and never raised.
    """
    import shutil  # only needed here, for cleanup on failure

    if file is None:
        return None, "No file was uploaded. Please upload a CSV file."
    if tier_name is None or not str(tier_name).strip():
        tier_name = "generated_tier"

    temp_dir = None
    try:
        # Create temporary directory
        temp_dir = tempfile.mkdtemp()
        zip_path, status = _convert_upload(file, tier_name, temp_dir)
    except Exception as e:
        # Top-level boundary of the UI callback: report, never raise.
        print(f"Error in processing: {str(e)}")
        zip_path, status = None, f"Error: {str(e)}"

    # On success the ZIP lives in temp_dir and is handed to the caller;
    # on failure nothing refers to the directory, so remove it.
    if zip_path is None and temp_dir is not None:
        shutil.rmtree(temp_dir, ignore_errors=True)
    return zip_path, status
# ==== Gradio Interface Setup ====
# Markdown help text appended to the interface description. Blank lines
# separate the paragraphs, the bullet lists and the table, and the column
# descriptions are nested under their parent bullet, so that Markdown
# renders them as distinct elements.
csv_format_instruction = """
**Expected CSV Format:**

Please ensure that the CSV file adheres to the following format:

- The first row must contain headers: `, file_name, xmin, xmax, text, is_unit_start_pred`.
- Each subsequent row should contain the following columns for every word or segment in the audio file:
    - `file_name`: Identifier for the audio file, used to group intervals.
    - `xmin`: Start time of the segment (in seconds).
    - `xmax`: End time of the segment (in seconds).
    - `text`: The actual spoken word or phrase.
    - `is_unit_start_pred`: Marks the beginning of a new unit (TRUE/FALSE).

**Please note: We currently only accept CSVs with an index.**

**Tier Name:**

Please enter the tier name according to your preference or as deemed appropriate for the data.

**Example CSV:**

|   | file_name | xmin  | xmax  | text  | is_unit_start_pred |
|---|-----------|-------|-------|-------|--------------------|
| 0 | example1  | 20.42 | 20.74 | mhmm  | TRUE               |
| 1 | example1  | 20.74 | 20.81 | hello | TRUE               |
| 2 | example1  | 20.81 | 20.92 | world | FALSE              |
"""
# Gradio UI: one CSV upload plus a tier-name box in; the ZIP of TextGrids and
# a status message out. The emoji in the labels are written as \N{...}
# escapes so they survive any re-encoding of this source file.
iface = gr.Interface(
    fn=csv_to_textgrid,
    inputs=[
        gr.File(label="\N{FILE FOLDER} Upload CSV File", file_types=[".csv"]),
        # Passed to csv_to_textgrid as its second argument (tier_name).
        gr.Textbox(label="\N{MEMO} Enter Tier Name", placeholder="Enter the name of the tier"),
    ],
    outputs=[
        gr.File(label="\N{PACKAGE} Download TextGrid ZIP"),
        gr.Textbox(label="\N{WHITE HEAVY CHECK MARK} Status"),
    ],
    title="CSV2Praat Auto Tool",
    description="Upload a properly formatted CSV file and receive one or more Praat TextGrid files in return. Files will be packed in a .zip for download.\n\n" + csv_format_instruction,
)

if __name__ == "__main__":
    iface.launch()