# finbertteacher_v1 / csv_inference.py
# Committed by aimlresearch2023 (initial commit 89c54bf).
"""CSV validation and batch export (no Gradio)."""
from __future__ import annotations
import os
import tempfile
from pathlib import Path
from typing import Callable, Optional
import pandas as pd
from inference import classify_batch
# Largest uploaded CSV (in bytes) that validate_csv_path accepts: 200 MiB.
MAX_BYTES = 200 << 20
def validate_csv_path(
    path: Optional[str],
    *,
    max_bytes: Optional[int] = None,
) -> tuple[bool, str]:
    """Check that *path* is a readable CSV, within the size limit, with a ``text`` column.

    Only the header row is parsed, so the check stays cheap for large files.

    Args:
        path: Filesystem path of the candidate CSV, or ``None``/``""`` when
            nothing has been supplied.
        max_bytes: Size limit in bytes. Defaults to the module-level
            ``MAX_BYTES`` when omitted.

    Returns:
        ``(ok, message)``. On success this is ``(True, "")``. An empty *path*
        gives ``(False, "")`` with no message. Every other failure gives
        ``False`` plus a human-readable reason.
    """
    if not path:
        return False, ""
    if not os.path.isfile(path):
        return False, "Not a file"
    limit = MAX_BYTES if max_bytes is None else max_bytes
    try:
        size = os.path.getsize(path)
    except OSError as exc:
        # The file can disappear or become unreadable after the isfile() check.
        print("validate_csv_path size:", exc)
        return False, "Could not read file"
    if size > limit:
        # Build the figure from the limit actually enforced, so the message
        # always matches the check.
        limit_mb = limit / (1024 * 1024)
        return False, f"File exceeds {limit_mb:g} MB (only the text column is used)."
    try:
        # nrows=0 parses just the header line.
        header = pd.read_csv(path, nrows=0)
    except Exception as exc:  # validation boundary: any parse failure rejects the file
        print("validate_csv_path read_csv:", exc)
        return False, "Could not parse CSV header"
    if "text" not in header.columns:
        return False, 'CSV must include a column named exactly "text".'
    return True, ""
def export_path_for_source(source_path: str) -> str:
    """Return a fresh output path for the results derived from *source_path*.

    Each call creates a new temporary directory (name prefixed ``bill_csv_``)
    and returns the path ``<stem>_v3_teacher_results.csv`` inside it. Only the
    directory is created here; the CSV file itself is not written.
    """
    filename = Path(source_path).stem + "_v3_teacher_results.csv"
    target_dir = Path(tempfile.mkdtemp(prefix="bill_csv_"))
    return str(target_dir.joinpath(filename))
def run_csv_inference(
    csv_path: str,
    progress_cb: Optional[Callable[[int, int], None]] = None,
    batch_size: int = 32,
) -> str:
    """Classify every row's ``text`` and write the augmented CSV to a temp path.

    Args:
        csv_path: Input CSV. It is expected to have passed
            ``validate_csv_path``; a missing ``text`` column raises
            ``KeyError`` here.
        progress_cb: Passed through unchanged to ``classify_batch``.
            Presumably called as ``(done, total)``; confirm against
            ``inference``.
        batch_size: Passed through to ``classify_batch``.

    Returns:
        Path of the written CSV. It holds every input column plus
        ``v3_teacher_results``, ``v3_teacher_label`` and ``v3_teacher_prob``.

    Raises:
        ValueError: if ``classify_batch`` returns a different number of rows
            than the CSV contains.
    """
    # dtype=str keeps cells as the text they were in the file. Without it,
    # pandas re-types columns: IDs like "00123" are exported as 123, and a
    # numeric-looking text column with blanks becomes float, so "42" would be
    # classified as "42.0". Blank/NA cells still come back as NaN.
    df = pd.read_csv(csv_path, dtype=str)
    texts = ["" if pd.isna(x) else str(x) for x in df["text"].tolist()]
    # Materialize in case classify_batch yields lazily; we need len() below.
    batch_rows = list(
        classify_batch(texts, batch_size=batch_size, progress_cb=progress_cb)
    )
    if len(batch_rows) != len(df):
        raise ValueError(
            f"classify_batch returned {len(batch_rows)} rows "
            f"for {len(df)} input rows"
        )
    # Each row is a (results, label, prob) triple; transpose into columns.
    # An empty input yields three empty columns so the header is still written.
    results, labels, probs = zip(*batch_rows) if batch_rows else ((), (), ())
    df["v3_teacher_results"] = list(results)
    df["v3_teacher_label"] = list(labels)
    df["v3_teacher_prob"] = list(probs)
    out_path = export_path_for_source(csv_path)
    df.to_csv(out_path, index=False)
    return out_path