File size: 6,027 Bytes
7804d7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import xml.etree.ElementTree as ET
import json
import os
from typing import List, Dict, Any
def parse_coords(coords_str: str) -> List[List[float]]:
    """
    Convert coordinates string "x1,y1 x2,y2 x3,y3 x4,y4" to LabelMe polygon format.

    Args:
        coords_str: Whitespace-separated "x,y" pairs.

    Returns:
        List of [x, y] float pairs in input order. An empty or
        whitespace-only string yields an empty list.

    Raises:
        ValueError: If a pair is not exactly two comma-separated numbers.
            The message names the offending pair and the full input string.
    """
    points = []
    # str.split() with no argument already ignores leading/trailing whitespace.
    for pair in coords_str.split():
        parts = pair.split(',')
        if len(parts) != 2:
            raise ValueError(
                f"Malformed coordinate pair {pair!r} in {coords_str!r}: expected 'x,y'"
            )
        try:
            points.append([float(parts[0]), float(parts[1])])
        except ValueError as e:
            raise ValueError(
                f"Non-numeric coordinate pair {pair!r} in {coords_str!r}"
            ) from e
    return points
def _coords_points(element):
    """Return the polygon parsed from *element*'s ``Coords`` child, or None.

    None is returned when there is no ``Coords`` child or when its ``points``
    attribute is missing or empty.
    """
    coords = element.find('Coords')
    if coords is None:
        return None
    points_str = coords.get('points')
    if not points_str:
        return None
    return parse_coords(points_str)


def _polygon_shape(label, points, table_idx, description):
    """Build one LabelMe polygon shape tagged with its parent table's group ID."""
    return {
        "label": label,
        "points": points,
        # Tables and their cells share this ID so they can be regrouped later.
        # NOTE(review): LabelMe itself appears to treat group_id as an integer;
        # confirm that a string ID is acceptable to downstream consumers.
        "group_id": f"table_{table_idx}",
        "shape_type": "polygon",
        "flags": {},
        "description": description,
    }


def _estimate_image_size(shapes, padding=50):
    """Return (width, height) guessed from the largest x/y found in *shapes*.

    Returns (0, 0) when the shapes contain no points at all.
    """
    xs = [point[0] for shape in shapes for point in shape["points"]]
    ys = [point[1] for shape in shapes for point in shape["points"]]
    if not xs or not ys:
        return 0, 0
    return int(max(xs)) + padding, int(max(ys)) + padding


def xml_to_labelme(xml_file_path: str, output_dir: str = None) -> Dict[str, Any]:
    """
    Convert XML table annotation to LabelMe JSON format.

    Every ``table`` element directly under the XML root contributes one
    "table" shape (if it has coordinates) plus one "cell" shape per ``cell``
    child that has coordinates.

    Args:
        xml_file_path: Path to input XML file
        output_dir: Output directory for JSON file (optional). When given, the
            directory is created if needed, ``<xml base name>.json`` is written
            into it and a short summary is printed. When omitted or empty,
            nothing is written or printed.

    Returns:
        Dictionary containing LabelMe format data. ``imageWidth`` and
        ``imageHeight`` are estimates only (largest coordinate plus 50 pixels
        of padding), or 0 when no shape has any points.

    Raises:
        ValueError: If the file is not well-formed XML (the original
            ``ParseError`` is chained as the cause). Errors raised by
            ``parse_coords`` for malformed ``points`` attributes propagate
            unchanged.
    """
    # Parse XML
    try:
        tree = ET.parse(xml_file_path)
    except ET.ParseError as e:
        raise ValueError(f"Invalid XML format: {e}") from e
    root = tree.getroot()

    # Get image filename from XML
    image_filename = root.get('filename', 'image.jpg')

    # Process all table elements (can be multiple tables in one XML)
    shapes = []
    table_count = 0
    cell_count = 0
    for table_idx, table in enumerate(root.findall('table')):
        table_points = _coords_points(table)
        if table_points is not None:
            shapes.append(
                _polygon_shape("table", table_points, table_idx, f"Table {table_idx + 1}")
            )
            table_count += 1

        # Cells are processed even when the table itself has no coordinates.
        for cell in table.findall('cell'):
            cell_points = _coords_points(cell)
            if cell_points is None:
                continue
            # Get cell attributes for additional info
            start_row = cell.get('start-row', '0')
            end_row = cell.get('end-row', '0')
            start_col = cell.get('start-col', '0')
            end_col = cell.get('end-col', '0')
            shapes.append(
                _polygon_shape(
                    "cell",
                    cell_points,
                    table_idx,
                    f"Table {table_idx + 1} - Row:{start_row}-{end_row}, Col:{start_col}-{end_col}",
                )
            )
            cell_count += 1

    # The XML carries no image size here, so estimate it from the coordinates.
    image_width, image_height = _estimate_image_size(shapes)

    labelme_data = {
        "version": "5.0.1",
        "flags": {},
        "shapes": shapes,
        "imagePath": image_filename,
        "imageData": None,
        "imageHeight": image_height,
        "imageWidth": image_width,
    }

    # Save to JSON file
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

        # Create output filename
        base_name = os.path.splitext(os.path.basename(xml_file_path))[0]
        json_filename = f"{base_name}.json"
        json_path = os.path.join(output_dir, json_filename)

        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(labelme_data, f, indent=2, ensure_ascii=False)

        print(f"Converted successfully! Output saved to: {json_path}")
        print(f"Found {len(labelme_data['shapes'])} shapes total:")
        print(f" - Tables: {table_count}")
        print(f" - Cells: {cell_count}")
        if table_count > 0:
            print(f" - Average cells per table: {cell_count / table_count:.1f}")

    return labelme_data
def batch_convert(input_dir: str, output_dir: str):
    """
    Convert all XML files in a directory to LabelMe JSON format.

    Files are processed in sorted name order. A failure on one file is
    reported on stdout and does not stop the remaining conversions.

    Args:
        input_dir: Directory containing XML files
        output_dir: Directory to save JSON files

    Raises:
        ValueError: If input_dir does not exist or is not a directory.
    """
    if not os.path.exists(input_dir):
        raise ValueError(f"Input directory does not exist: {input_dir}")
    # A regular file passes the exists() check but would make os.listdir fail
    # with NotADirectoryError; report it with the same exception type instead.
    if not os.path.isdir(input_dir):
        raise ValueError(f"Input path is not a directory: {input_dir}")

    # os.listdir order is arbitrary; sort so runs are reproducible.
    xml_files = sorted(f for f in os.listdir(input_dir) if f.endswith('.xml'))
    if not xml_files:
        print(f"No XML files found in {input_dir}")
        return

    print(f"Found {len(xml_files)} XML files to convert...")
    success_count = 0
    for xml_file in xml_files:
        try:
            xml_to_labelme(os.path.join(input_dir, xml_file), output_dir)
        except Exception as e:
            # Batch boundary: report the failure and keep going.
            print(f"Error converting {xml_file}: {e}")
        else:
            success_count += 1

    print(f"\nConversion completed! Successfully converted {success_count}/{len(xml_files)} files.")
# Example usage:
#   python <this script> [input_dir [output_dir]]
# With no arguments the original hard-coded directory is used for both input
# and output; with only input_dir given, JSON files are written next to the XML.
if __name__ == "__main__":
    import sys

    default_dir = '/Users/tuvn18/Desktop/tuvn18/dev/KIAI/dev/trace/src/train_trace_page39'
    input_dir = sys.argv[1] if len(sys.argv) > 1 else default_dir
    output_dir = sys.argv[2] if len(sys.argv) > 2 else input_dir
    batch_convert(input_dir, output_dir)