File size: 1,918 Bytes
51f795f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd91161
51f795f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import sys
from pathlib import Path
import xml.etree.ElementTree as ET

def build_xml(dir_path: Path, root_level: bool = True) -> ET.Element:
    """
    Recursively build an XML tree that mirrors a directory.

    Root: <dataset>
    Subfolders: <folder name="...">
    Files: <file name="...">

    Inside every element the files are listed first and the subfolders
    after them, each group in sorted path order. Entries that are neither
    a regular file nor a directory are left out.

    Args:
        dir_path: Directory to scan.
        root_level: True for the top-level call (produces <dataset>);
            False for recursive calls (produces <folder name="...">).

    Returns:
        The element describing ``dir_path`` and everything below it.
    """
    if root_level:
        elem = ET.Element("dataset")
    else:
        elem = ET.Element("folder", {"name": dir_path.name})

    # List and sort the directory a single time; both passes below work on
    # the same snapshot instead of hitting the filesystem twice.
    entries = sorted(dir_path.iterdir())

    # Add files
    for entry in entries:
        if entry.is_file():
            ET.SubElement(elem, "file", {"name": entry.name})

    # Add subdirectories
    # NOTE(review): is_dir() follows symlinks, so a symlink cycle is not
    # guarded against here.
    for entry in entries:
        if entry.is_dir():
            elem.append(build_xml(entry, root_level=False))

    return elem

def indent(elem: ET.Element, level: int = 0) -> None:
    """
    Pretty-print an element tree in place using two-space indentation.

    Only whitespace-only (or missing) ``text``/``tail`` values are
    rewritten; anything containing real characters is left alone.
    ``level`` is the nesting depth of ``elem``; at depth 0 the element's
    own tail is never touched.
    """
    pad = "\n" + level * "  "
    children = list(elem)

    if children:
        # Text before the first child: one step deeper than this element.
        if not (elem.text and elem.text.strip()):
            elem.text = pad + "  "

        for node in children:
            indent(node, level + 1)

        # The last child's tail precedes this element's closing tag, so it
        # goes back to this element's own depth.
        last = children[-1]
        if not (last.tail and last.tail.strip()):
            last.tail = pad

    if level != 0 and not (elem.tail and elem.tail.strip()):
        elem.tail = pad

def main(directory: str = "sample_dataset") -> None:
    """
    Print an indented XML listing of ``directory`` to stdout.

    Args:
        directory: Path of the directory to describe.

    Exits with status 1 (after writing a message to stderr) when
    ``directory`` is not an existing directory.
    """
    path = Path(directory)
    if not path.is_dir():
        print(f"Error: '{directory}' is not a directory.", file=sys.stderr)
        sys.exit(1)

    root_elem = build_xml(path, root_level=True)
    indent(root_elem)

    # ET.tostring() adds no XML declaration for UTF-8 output, so it is
    # prepended by hand before printing.
    # NOTE(review): print() encodes with stdout's own encoding, which may
    # not be the UTF-8 named in the declaration - confirm for your platform.
    xml_bytes = ET.tostring(root_elem, encoding="utf-8")
    xml_string = b'<?xml version="1.0" encoding="UTF-8"?>\n' + xml_bytes
    print(xml_string.decode("utf-8"))

if __name__ == "__main__":
    # Forward the first command-line argument, when one is given, as the
    # directory to scan; with no argument main() uses its own default.
    cli_args = sys.argv[1:2]
    main(*cli_args)