import asyncio
import io
import json
import os
import socket
import threading
import zipfile
from pathlib import Path

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from pyrogram import Client
|
|
# --- Configuration ---------------------------------------------------------
name = "NeonTech"  # Pyrogram session name
session_string = os.environ.get("session_string")  # Telegram session string; None if unset
api_id = os.environ.get("api_id")  # NOTE(review): env vars are str — Pyrogram usually wants int api_id; confirm
api_hash = os.environ.get("api_hash")
channel = os.environ.get("channel")  # chat id/username whose history is scanned for .zip documents
OUT_DIR = "/data/image-shards"  # destination directory for parquet shards
STATE_FILE = "/data/images_state.json"  # resume-state checkpoint path
IMAGES_PER_SHARD = 1000  # rows buffered before a shard is flushed
DOWNLOAD_DELAY = 10  # seconds to sleep between processed messages


os.makedirs(OUT_DIR, exist_ok=True)  # ensure shard directory exists before any writes
|
|
| |
def serve():
    """Minimal HTTP health-check responder on port 7860.

    Answers every incoming connection with a fixed ``200 OK`` and closes it.
    Runs forever; intended to be started in a daemon thread so the hosting
    platform sees the process as alive.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("0.0.0.0", 7860))
        s.listen(5)
        print("β Listening on port 7860")
        while True:
            conn, _ = s.accept()
            try:
                # sendall, not send: send() may transmit only part of the buffer
                conn.sendall(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK")
            except OSError:
                pass  # client disappeared mid-reply; keep serving others
            finally:
                conn.close()
|
|
| |
def load_state(path=None):
    """Load the resume-state checkpoint.

    Args:
        path: state-file location; defaults to the module-level ``STATE_FILE``.

    Returns:
        dict with keys ``done_msgs`` (list of processed message ids),
        ``total_images`` (int) and ``total_shards`` (int).

    A missing, unreadable, or corrupt state file starts a fresh run instead of
    crashing, so a half-written checkpoint cannot brick the job; missing keys
    in an older state file are filled with defaults.
    """
    if path is None:
        path = STATE_FILE
    fresh = {"done_msgs": [], "total_images": 0, "total_shards": 0}
    if os.path.exists(path):
        try:
            with open(path) as f:
                state = json.load(f)
            for key, default in fresh.items():
                state.setdefault(key, default)  # tolerate older/partial files
            print(f"Resuming β {len(state['done_msgs'])} zips done | {state['total_images']} images | {state['total_shards']} shards")
            return state
        except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError
            print(f" β State file unreadable ({e}) β starting fresh")
            return fresh
    print("Starting fresh")
    return fresh
|
|
def save_state(state, path=None):
    """Atomically persist *state* as JSON.

    Args:
        state: the resume dict (``done_msgs`` / ``total_images`` / ``total_shards``).
        path: destination file; defaults to the module-level ``STATE_FILE``.

    Writes to a temporary sibling file and ``os.replace``s it into place, so a
    crash mid-write can never leave a truncated checkpoint behind. Failures
    are logged rather than raised: checkpointing is best-effort.
    """
    if path is None:
        path = STATE_FILE
    tmp = f"{path}.tmp"
    try:
        with open(tmp, "w") as f:
            json.dump(state, f, indent=2)
        os.replace(tmp, path)  # atomic rename on POSIX
    except OSError as e:
        print(f" β State save failed: {e}")
|
|
| |
def write_shard(rows, shard_idx):
    """Persist *rows* (list of row dicts) as one snappy-compressed parquet shard.

    The shard is written to ``OUT_DIR`` as ``shard_<idx, zero-padded>.parquet``.
    """
    out_path = Path(OUT_DIR) / f"shard_{shard_idx:06d}.parquet"
    frame = pd.DataFrame(rows)
    pq.write_table(pa.Table.from_pandas(frame), out_path, compression="snappy")
    print(f" β shard_{shard_idx:06d}.parquet | {len(rows)} images")
|
|
| |
def parse_zip(zf):
    """Extract image rows from an open ``zipfile.ZipFile``.

    Every member that is not a directory, ``.json``, ``.txt``, or
    ``.meta.json`` entry is treated as an image.  Metadata resolution order:
    a per-image sidecar ``<name>.meta.json`` wins; otherwise the batch-level
    ``captions.txt`` (lines of ``filename||caption``) and ``tags.json``
    (filename -> tags mapping) are consulted.

    Returns:
        list of dicts with keys ``image_id``, ``image`` (raw bytes),
        ``description`` (str) and ``tags`` (list).
    """
    members = set(zf.namelist())
    skip = (".meta.json", ".json", ".txt")
    images = [m for m in members if not m.endswith(skip) and not m.endswith("/")]

    # Batch-level metadata shared by many images.
    captions = {}
    tag_map = {}
    if "captions.txt" in members:
        for raw_line in zf.read("captions.txt").decode("utf-8").splitlines():
            if "||" not in raw_line:
                continue
            fname, caption = raw_line.split("||", 1)
            captions[fname.strip()] = caption.strip()
    if "tags.json" in members:
        tag_map = dict(json.loads(zf.read("tags.json")))

    rows = []
    for member in images:
        try:
            payload = zf.read(member)
        except Exception:
            continue  # unreadable entry: best-effort skip

        description = ""
        tags = []
        sidecar = f"{member}.meta.json"
        if sidecar in members:
            # Per-image sidecar takes precedence over batch metadata.
            try:
                meta = json.loads(zf.read(sidecar))
                description = meta.get("description", "")
                tags = meta.get("tags", [])
            except Exception:
                pass
        elif member in captions or member in tag_map:
            description = captions.get(member, "")
            tags = tag_map.get(member, [])

        rows.append({
            "image_id": Path(member).stem,
            "image": payload,
            "description": description,
            "tags": tags,
        })

    return rows
|
|
| |
async def process(state):
    """Walk the channel history, download each .zip document, and pack the
    images inside into fixed-size parquet shards.

    *state* is the resume dict from load_state(); it is mutated in place and
    checkpointed via save_state() as work progresses.
    """
    buffer = []  # rows not yet flushed to a shard; carries over between zips
    shard_idx = state["total_shards"]

    async with Client(name=name, session_string=session_string, api_id=api_id, api_hash=api_hash) as app:
        async for msg in app.get_chat_history(channel):
            # Only document messages whose filename ends in .zip are processed.
            if not msg.document:
                continue
            if not msg.document.file_name or not msg.document.file_name.endswith(".zip"):
                continue
            if msg.id in state["done_msgs"]:
                print(f" SKIP msg {msg.id}")
                continue

            print(f"Downloading msg {msg.id} β {msg.document.file_name} ({msg.document.file_size:,} bytes)")

            # Retry the download up to 10 times with linearly growing backoff
            # (300s, 600s, ..., 3000s).
            # NOTE(review): the "after 25min" message below understates the
            # actual worst-case total backoff (~4.6 h over 10 attempts) —
            # confirm which was intended.
            for attempt in range(10):
                try:
                    zip_bytes = io.BytesIO()  # entire archive is buffered in RAM
                    async for chunk in app.stream_media(msg):
                        zip_bytes.write(chunk)
                    zip_bytes.seek(0)
                    break
                except Exception as e:
                    wait = 300 * (attempt + 1)
                    print(f" β Attempt {attempt + 1} failed: {e} β waiting {wait//60}min")
                    await asyncio.sleep(wait)
            else:
                # All attempts exhausted: give up and let the next run resume.
                print(f" β Telegram unavailable after 25min β stopping")
                return

            try:
                with zipfile.ZipFile(zip_bytes) as zf:
                    rows = parse_zip(zf)

                for row in rows:
                    buffer.append(row)
                    state["total_images"] += 1

                    # Flush a full shard and checkpoint immediately.
                    if len(buffer) >= IMAGES_PER_SHARD:
                        write_shard(buffer, shard_idx)
                        shard_idx += 1
                        state["total_shards"] = shard_idx
                        buffer = []
                        save_state(state)

            except Exception as e:
                # A corrupt archive skips the whole message (it stays un-done,
                # so a later run will retry it).
                print(f" β zip error msg {msg.id}: {e}")
                continue

            # NOTE(review): the message is marked done even though some of its
            # rows may still sit in `buffer`; a crash before the final flush
            # below loses those buffered images — confirm this is acceptable.
            state["done_msgs"].append(msg.id)
            save_state(state)
            print(f" β msg {msg.id} done | {state['total_images']} images total")
            await asyncio.sleep(DOWNLOAD_DELAY)

    # Flush whatever is left as a final (possibly short) shard.
    if buffer:
        write_shard(buffer, shard_idx)
        state["total_shards"] = shard_idx + 1
        save_state(state)

    print(f"\nβ All done! {state['total_images']} images | {state['total_shards']} shards")
|
|
| |
def main():
    """Start the health-check server in a background thread, then run the pipeline."""
    health_thread = threading.Thread(target=serve, daemon=True)
    health_thread.start()
    asyncio.run(process(load_state()))


if __name__ == "__main__":
    main()