File size: 1,200 Bytes
701f414
 
 
 
 
 
 
ae57a5a
 
701f414
 
 
 
ae57a5a
 
701f414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
import io
import json
import time
import socket
import asyncio
import threading
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq
import pandas as pd

# Directory that fix_shards() scans for the *.parquet shards it rewrites.
OUT_DIR = "/data/image-shards"

def serve(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Answer every TCP connection on ``host:port`` with a fixed ``200 OK``.

    This is a blocking, single-threaded accept loop that never returns; run it
    in a background thread. The incoming request is not parsed: whatever the
    client sends, it receives the same two-byte ``OK`` body.

    Args:
        host: Interface to bind to. Defaults to all interfaces.
        port: TCP port to listen on. Defaults to 7860.

    Raises:
        OSError: If the listening socket cannot be created or bound, or if
            ``accept`` itself fails. Errors on an individual client
            connection are logged and do not stop the loop.
    """
    response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK"

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind((host, port))
        listener.listen(5)
        print(f"✓ Listening on port {port}")

        while True:
            conn, _ = listener.accept()
            # `with` guarantees the client socket is closed even when the
            # peer resets the connection mid-response.
            with conn:
                try:
                    # Bound the wait so a client that connects but never
                    # sends anything cannot stall the loop.
                    conn.settimeout(1.0)
                    try:
                        # Drain the request first: closing a socket that
                        # still has unread data can make the kernel send RST,
                        # and the client may then never see the response.
                        conn.recv(4096)
                    except socket.timeout:
                        pass  # bare TCP probe with no payload; still reply
                    # sendall() retries partial writes; send() may not
                    # transmit the whole buffer.
                    conn.sendall(response)
                except OSError as exc:
                    # One misbehaving client must not kill the serving thread.
                    print(f"Health check connection dropped: {exc}")

def _as_image_struct(cell):
    """Return *cell* as a ``{"bytes": ..., "path": None}`` dict.

    A dict cell contributes its ``"bytes"`` entry (a missing key raises
    ``KeyError``); any other value is treated as the raw bytes payload itself.
    ``"path"`` is always set to ``None``.
    """
    payload = cell["bytes"] if isinstance(cell, dict) else cell
    return {"bytes": payload, "path": None}


def fix_shards(out_dir=None):
    """Rewrite the ``image`` column of every Parquet shard in a directory.

    Each ``*.parquet`` file is loaded into pandas, every ``image`` cell is
    normalised to ``{"bytes": ..., "path": None}``, and the file is written
    back with snappy compression under the same name.
    NOTE(review): presumably this matches the image struct layout a downstream
    dataset loader expects; confirm with the consumer of these shards.

    Args:
        out_dir: Directory containing the shards. Defaults to ``OUT_DIR``.

    Raises:
        KeyError: If a shard has no ``image`` column, or a dict cell has no
            ``"bytes"`` entry. The shard being processed is left unmodified.
    """
    root = Path(OUT_DIR if out_dir is None else out_dir)
    # Sorted so that progress output is reproducible between runs; glob()
    # yields files in arbitrary order.
    shards = sorted(root.glob("*.parquet"))
    print(f"Found {len(shards)} shards")

    for shard in shards:
        df = pq.read_table(shard).to_pandas()
        df["image"] = df["image"].apply(_as_image_struct)

        # Write next to the shard and swap it in with os.replace() so an
        # interruption mid-write can never leave a truncated shard behind.
        # The ".tmp" suffix keeps leftovers out of the "*.parquet" glob.
        tmp = shard.with_name(shard.name + ".tmp")
        try:
            pq.write_table(pa.Table.from_pandas(df), tmp, compression="snappy")
            os.replace(tmp, shard)
        except BaseException:
            # Do not leave a partial temp file behind on failure or Ctrl-C.
            if tmp.exists():
                tmp.unlink()
            raise
        print(f"✓ Fixed {shard.name}")

    print("Done")

def main():
    """Start the HTTP responder in the background, then fix the shards.

    The responder runs in a daemon thread, so it does not keep the process
    alive once ``fix_shards`` has returned.
    """
    responder = threading.Thread(target=serve, daemon=True)
    responder.start()
    fix_shards()

# Run the job only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()