# processing/utils/schema.py
# Centralized SFT writer (JSONL + CSV)
import csv
import orjson
from typing import Optional, Dict
import logging
# Logger
# Module-level logger named "schema"; callers elsewhere log through it.
# The handler guard prevents attaching duplicate StreamHandlers if this
# module's setup code runs more than once (e.g. repeated imports/reloads).
logger = logging.getLogger("schema")
if not logger.handlers:
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())
def sft_row(instruction: str, user_input: str, output: str, source: str, rid: str, task: str, meta: Optional[dict] = None):
    """Assemble one SFT record in the centralized schema.

    The instruction/input/output triple is nested under the "sft" key;
    provenance fields (source, id, task) sit at the top level. A missing
    or falsy ``meta`` becomes an empty dict so consumers can index safely.
    """
    sft_payload = {
        "instruction": instruction,
        "input": user_input,
        "output": output,
    }
    record = {
        "source": source,
        "id": rid,
        "task": task,
        "sft": sft_payload,
    }
    record["meta"] = meta or {}
    return record
def is_valid_row(row: Dict, max_chars: int = 20000) -> bool:
    """Cheap sanity filter for an SFT record.

    Accepts a row only if at least one of sft.input / sft.output is
    non-empty and none of the three sft text fields exceeds ``max_chars``.
    Missing keys are treated as empty strings.
    """
    sft = row.get("sft", {})
    instr = sft.get("instruction", "")
    inp = sft.get("input", "")
    out = sft.get("output", "")
    # Require some content on at least one side of the pair.
    if not inp and not out:
        return False
    # Cap pathological lengths across all three text fields.
    return all(len(field) <= max_chars for field in (instr, inp, out))
class CentralisedWriter:
    """Streams JSONL + CSV in parallel to stay memory-safe.

    Rows are validated with is_valid_row() before writing; invalid rows
    are logged and silently dropped. Instances may also be used as a
    context manager, which closes both files on exit.
    """

    # Column order shared by the CSV header and every data row.
    _CSV_FIELDS = ["instruction", "input", "output", "source", "id", "task"]

    def __init__(self, jsonl_path: str, csv_path: str):
        """Open both output files and write the CSV header.

        Raises:
            OSError: if either path cannot be opened. The already-open
                JSONL handle is closed when the CSV open fails, so a
                partial failure does not leak a file descriptor.
        """
        self.jsonl_fp = open(jsonl_path, "wb")
        try:
            self.csv_fp = open(csv_path, "w", newline="", encoding="utf-8")
        except Exception:
            # Bug fix: without this, a failed CSV open leaked jsonl_fp.
            self.jsonl_fp.close()
            raise
        self.csv_wr = csv.DictWriter(self.csv_fp, fieldnames=self._CSV_FIELDS)
        self.csv_wr.writeheader()

    def __enter__(self):
        """Support `with CentralisedWriter(...) as w:` usage."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close both files on context exit; never suppress exceptions."""
        self.close()
        return False

    def write(self, row: dict):
        """Append one SFT row to both outputs; skip (and log) invalid rows."""
        if not is_valid_row(row):
            s = row.get("sft", {})
            # Lazy %-args keep formatting cost off the happy path.
            logger.warning(
                "[WRITER] Skipping invalid row id=%s (len instr=%d, input=%d, output=%d)",
                row.get("id"),
                len(s.get("instruction", "")),
                len(s.get("input", "")),
                len(s.get("output", "")),
            )
            return
        # JSONL: one orjson-encoded object per line.
        self.jsonl_fp.write(orjson.dumps(row))
        self.jsonl_fp.write(b"\n")
        # CSV: flattened view of the same record.
        s = row["sft"]
        self.csv_wr.writerow({
            "instruction": s.get("instruction", ""),
            "input": s.get("input", ""),
            "output": s.get("output", ""),
            "source": row.get("source", ""),
            "id": row.get("id", ""),
            "task": row.get("task", ""),
        })

    def close(self):
        """Close both handles; the CSV is closed even if the JSONL close raises."""
        try:
            self.jsonl_fp.close()
        finally:
            self.csv_fp.close()