File size: 6,027 Bytes
7804d7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import json
import os
import xml.etree.ElementTree as ET
from typing import Any, Dict, List, Optional

def parse_coords(coords_str: str) -> List[List[float]]:
    """
    Turn a "x1,y1 x2,y2 ..." coordinate string into LabelMe polygon points.

    Pairs are separated by whitespace and the two values inside a pair by a
    comma. Each pair becomes an ``[x, y]`` list of floats; an empty or
    whitespace-only string yields an empty list.
    """
    # Split lazily so a malformed pair fails at the same position it would
    # in a plain loop over the pairs.
    split_pairs = (token.split(',') for token in coords_str.strip().split())
    return [[float(x_text), float(y_text)] for x_text, y_text in split_pairs]

def _polygon_shape(element: ET.Element, label: str, group_id: str,
                   description: str) -> Optional[Dict[str, Any]]:
    """
    Build one LabelMe polygon shape from the element's <Coords points="..."> child.

    Returns None when the element has no Coords child, or when its points
    attribute is missing or empty, so callers can skip elements without geometry.
    """
    coords = element.find('Coords')
    if coords is None:
        return None

    points_str = coords.get('points')
    if not points_str:
        return None

    return {
        "label": label,
        "points": parse_coords(points_str),
        "group_id": group_id,
        "shape_type": "polygon",
        "flags": {},
        "description": description,
    }


def xml_to_labelme(xml_file_path: str, output_dir: Optional[str] = None) -> Dict[str, Any]:
    """
    Convert XML table annotation to LabelMe JSON format

    Every <table> that is a direct child of the root element, and every <cell>
    directly inside such a table, becomes a polygon shape if it carries a
    <Coords points="..."> child. A table and its cells share the same group_id.

    Args:
        xml_file_path: Path to input XML file
        output_dir: Output directory for JSON file (optional). When given, the
            directory is created if needed, the result is written to
            "<xml base name>.json" inside it and a short summary is printed.

    Returns:
        Dictionary containing LabelMe format data

    Raises:
        ValueError: If the file is not well-formed XML (the original
            ParseError is attached as the cause) or a points string is malformed.
    """

    # Parse XML
    try:
        root = ET.parse(xml_file_path).getroot()
    except ET.ParseError as e:
        # Chain the original error so the parser's traceback is not lost.
        raise ValueError(f"Invalid XML format: {e}") from e

    # Initialize LabelMe structure; the image name comes from the root element
    labelme_data: Dict[str, Any] = {
        "version": "5.0.1",
        "flags": {},
        "shapes": [],
        "imagePath": root.get('filename', 'image.jpg'),
        "imageData": None,
        "imageHeight": 0,  # Replaced below by an estimate when shapes have points
        "imageWidth": 0    # Replaced below by an estimate when shapes have points
    }
    shapes = labelme_data["shapes"]

    # Process all table elements (can be multiple tables in one XML)
    table_count = 0
    cell_count = 0

    for table_idx, table in enumerate(root.findall('table')):
        group_id = f"table_{table_idx}"  # Shared by the table and all of its cells
        table_title = f"Table {table_idx + 1}"

        table_shape = _polygon_shape(table, "table", group_id, table_title)
        if table_shape is not None:
            shapes.append(table_shape)
            table_count += 1

        # Cells are processed even when the table itself has no coordinates
        for cell in table.findall('cell'):
            start_row = cell.get('start-row', '0')
            end_row = cell.get('end-row', '0')
            start_col = cell.get('start-col', '0')
            end_col = cell.get('end-col', '0')

            cell_shape = _polygon_shape(
                cell,
                "cell",
                group_id,
                f"{table_title} - Row:{start_row}-{end_row}, Col:{start_col}-{end_col}",
            )
            if cell_shape is not None:
                shapes.append(cell_shape)
                cell_count += 1

    # The XML is not read for image dimensions, so estimate them from the
    # largest coordinates seen plus a fixed 50 px of padding.
    all_points = [point for shape in shapes for point in shape["points"]]
    if all_points:
        labelme_data["imageWidth"] = int(max(point[0] for point in all_points)) + 50
        labelme_data["imageHeight"] = int(max(point[1] for point in all_points)) + 50

    # Save to JSON file
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

        # Output file keeps the XML's base name with a .json extension
        base_name = os.path.splitext(os.path.basename(xml_file_path))[0]
        json_path = os.path.join(output_dir, f"{base_name}.json")

        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(labelme_data, f, indent=2, ensure_ascii=False)

        print(f"Converted successfully! Output saved to: {json_path}")
        print(f"Found {len(labelme_data['shapes'])} shapes total:")
        print(f"  - Tables: {table_count}")
        print(f"  - Cells: {cell_count}")
        if table_count > 0:
            print(f"  - Average cells per table: {cell_count / table_count:.1f}")

    return labelme_data

def batch_convert(input_dir: str, output_dir: str) -> None:
    """
    Convert all XML files in a directory to LabelMe JSON format

    Only entries directly inside input_dir whose name ends with ".xml" are
    converted, in sorted name order. A failure on one file is reported and
    does not stop the rest of the batch.

    Args:
        input_dir: Directory containing XML files
        output_dir: Directory to save JSON files

    Raises:
        ValueError: If input_dir does not exist
    """

    if not os.path.exists(input_dir):
        raise ValueError(f"Input directory does not exist: {input_dir}")

    # os.listdir returns entries in arbitrary order; sort so the processing
    # order and the printed log are reproducible between runs.
    xml_files = sorted(f for f in os.listdir(input_dir) if f.endswith('.xml'))

    if not xml_files:
        print(f"No XML files found in {input_dir}")
        return

    print(f"Found {len(xml_files)} XML files to convert...")

    success_count = 0
    for xml_file in xml_files:
        xml_path = os.path.join(input_dir, xml_file)
        try:
            xml_to_labelme(xml_path, output_dir)
        except Exception as e:
            # Best-effort batch: report the failing file and keep going.
            print(f"Error converting {xml_file}: {e}")
        else:
            success_count += 1

    print(f"\nConversion completed! Successfully converted {success_count}/{len(xml_files)} files.")

# Example usage:
#   python <this script> [input_dir] [output_dir]
# With no arguments the original development directory is used for both input
# and output; with only input_dir, the JSON files are written next to the XML.
if __name__ == "__main__":
    import argparse

    default_dir = '/Users/tuvn18/Desktop/tuvn18/dev/KIAI/dev/trace/src/train_trace_page39'

    parser = argparse.ArgumentParser(
        description="Convert XML table annotations in a directory to LabelMe JSON files."
    )
    parser.add_argument('input_dir', nargs='?', default=default_dir,
                        help="directory containing the .xml files")
    parser.add_argument('output_dir', nargs='?', default=None,
                        help="directory for the .json files (defaults to input_dir)")
    args = parser.parse_args()

    target_dir = args.input_dir if args.output_dir is None else args.output_dir
    batch_convert(args.input_dir, target_dir)