|
|
import xml.etree.ElementTree as ET |
|
|
import json |
|
|
import os |
|
|
from typing import List, Dict, Any |
|
|
|
|
|
def parse_coords(coords_str: str) -> List[List[float]]:
    """
    Convert coordinates string "x1,y1 x2,y2 x3,y3 x4,y4" to LabelMe polygon format

    Args:
        coords_str: Whitespace-separated "x,y" pairs

    Returns:
        List of [x, y] float pairs in input order; empty list for a blank string

    Raises:
        ValueError: If a pair is not exactly two comma-separated numbers
    """
    points = []

    # split() with no argument already ignores leading/trailing whitespace
    # and collapses runs of spaces, so no explicit strip() is needed.
    for pair in coords_str.split():
        parts = pair.split(',')
        if len(parts) != 2:
            raise ValueError(
                f"Malformed coordinate pair {pair!r}: expected 'x,y'"
            )
        try:
            points.append([float(parts[0]), float(parts[1])])
        except ValueError as e:
            # Re-raise with the offending pair so bad annotations are easy to find.
            raise ValueError(
                f"Non-numeric coordinate pair {pair!r}: {e}"
            ) from e

    return points
|
|
|
|
|
def _read_polygon(element):
    """Return the parsed <Coords points="..."> polygon of *element*, or None if absent or empty."""
    coords = element.find('Coords')
    if coords is None:
        return None
    points_str = coords.get('points')
    if not points_str:
        return None
    return parse_coords(points_str)


def xml_to_labelme(xml_file_path: str, output_dir: str = None) -> Dict[str, Any]:
    """
    Convert XML table annotation to LabelMe JSON format

    Every <table> directly under the root that has a <Coords points="...">
    child becomes a "table" polygon; every <cell> directly under a table that
    has one becomes a "cell" polygon. Shapes of the same table share a
    group_id of the form "table_<index>".

    Args:
        xml_file_path: Path to input XML file
        output_dir: Output directory for JSON file (optional). When given,
            the directory is created if needed, "<xml basename>.json" is
            written into it, and a short summary is printed.

    Returns:
        Dictionary containing LabelMe format data

    Raises:
        ValueError: If the XML is not well-formed, or a points attribute
            cannot be parsed by parse_coords
    """
    try:
        tree = ET.parse(xml_file_path)
        root = tree.getroot()
    except ET.ParseError as e:
        # Chain the original error so the parser's position info is kept.
        raise ValueError(f"Invalid XML format: {e}") from e

    image_filename = root.get('filename', 'image.jpg')

    labelme_data = {
        "version": "5.0.1",
        "flags": {},
        "shapes": [],
        "imagePath": image_filename,
        "imageData": None,
        "imageHeight": 0,
        "imageWidth": 0
    }
    shapes = labelme_data["shapes"]

    table_count = 0
    cell_count = 0

    for table_idx, table in enumerate(root.findall('table')):
        # NOTE(review): group_id is emitted as a string; confirm the LabelMe
        # version in use accepts non-integer group ids.
        group_id = f"table_{table_idx}"

        table_points = _read_polygon(table)
        if table_points is not None:
            shapes.append({
                "label": "table",
                "points": table_points,
                "group_id": group_id,
                "shape_type": "polygon",
                "flags": {},
                "description": f"Table {table_idx + 1}"
            })
            table_count += 1

        # Cells are converted even when the table itself has no outline.
        for cell in table.findall('cell'):
            cell_points = _read_polygon(cell)
            if cell_points is None:
                continue

            start_row = cell.get('start-row', '0')
            end_row = cell.get('end-row', '0')
            start_col = cell.get('start-col', '0')
            end_col = cell.get('end-col', '0')

            shapes.append({
                "label": "cell",
                "points": cell_points,
                "group_id": group_id,
                "shape_type": "polygon",
                "flags": {},
                "description": f"Table {table_idx + 1} - Row:{start_row}-{end_row}, Col:{start_col}-{end_col}"
            })
            cell_count += 1

    # The image size is not read from the XML here, so it is estimated as the
    # furthest polygon coordinate plus a 50-pixel margin. It stays 0x0 when
    # no points were collected.
    all_x = [point[0] for shape in shapes for point in shape["points"]]
    all_y = [point[1] for shape in shapes for point in shape["points"]]

    if all_x and all_y:
        labelme_data["imageWidth"] = int(max(all_x)) + 50
        labelme_data["imageHeight"] = int(max(all_y)) + 50

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

        # The JSON file takes the XML file's base name.
        base_name = os.path.splitext(os.path.basename(xml_file_path))[0]
        json_filename = f"{base_name}.json"
        json_path = os.path.join(output_dir, json_filename)

        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(labelme_data, f, indent=2, ensure_ascii=False)

        print(f"Converted successfully! Output saved to: {json_path}")
        print(f"Found {len(labelme_data['shapes'])} shapes total:")
        print(f" - Tables: {table_count}")
        print(f" - Cells: {cell_count}")
        if table_count > 0:
            print(f" - Average cells per table: {cell_count / table_count:.1f}")

    return labelme_data
|
|
|
|
|
def batch_convert(input_dir: str, output_dir: str):
    """
    Convert all XML files in a directory to LabelMe JSON format

    Files are matched by a case-insensitive ".xml" extension and processed in
    sorted name order. A failure on one file is reported and does not stop
    the remaining conversions.

    Args:
        input_dir: Directory containing XML files
        output_dir: Directory to save JSON files

    Raises:
        ValueError: If input_dir does not exist or is not a directory
    """
    if not os.path.exists(input_dir):
        raise ValueError(f"Input directory does not exist: {input_dir}")
    if not os.path.isdir(input_dir):
        # Without this check a regular file would only fail later in os.listdir.
        raise ValueError(f"Input path is not a directory: {input_dir}")

    # Sort for a reproducible order; lower() so ".XML" files are not skipped.
    xml_files = sorted(
        f for f in os.listdir(input_dir) if f.lower().endswith('.xml')
    )

    if not xml_files:
        print(f"No XML files found in {input_dir}")
        return

    print(f"Found {len(xml_files)} XML files to convert...")

    success_count = 0
    for xml_file in xml_files:
        xml_path = os.path.join(input_dir, xml_file)
        try:
            xml_to_labelme(xml_path, output_dir)
        except Exception as e:
            # Deliberately broad: this is the batch boundary, and one bad file
            # must not abort the rest. The error is reported, not swallowed.
            print(f"Error converting {xml_file}: {e}")
        else:
            success_count += 1

    print(f"\nConversion completed! Successfully converted {success_count}/{len(xml_files)} files.")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Local import: argparse is only needed when run as a script.
    import argparse

    # Historical default location; used when no arguments are given so that
    # running the script bare behaves as it always has.
    _default_dir = '/Users/tuvn18/Desktop/tuvn18/dev/KIAI/dev/trace/src/train_trace_page39'

    _parser = argparse.ArgumentParser(
        description="Convert XML table annotations in a directory to LabelMe JSON."
    )
    _parser.add_argument(
        "input_dir", nargs="?", default=_default_dir,
        help="directory containing XML files",
    )
    _parser.add_argument(
        "output_dir", nargs="?", default=None,
        help="directory for JSON output (defaults to input_dir)",
    )
    _args = _parser.parse_args()

    # JSON files are written next to the XML files unless told otherwise.
    batch_convert(_args.input_dir, _args.output_dir or _args.input_dir)