File size: 2,031 Bytes
89c54bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""CSV validation and batch export (no Gradio)."""

from __future__ import annotations

import os
import tempfile
from pathlib import Path
from typing import Callable, Optional

import pandas as pd

from inference import classify_batch

# Upload size cap in bytes (200 MiB); validate_csv_path rejects larger files.
# NOTE(review): confirm this is the intended cap and that user-facing size
# messages agree with it.
MAX_BYTES = 200 * 2**20


def validate_csv_path(
    path: Optional[str],
    max_bytes: Optional[int] = None,
) -> tuple[bool, str]:
    """Check that *path* is an acceptable CSV with a ``text`` column.

    Only the header row is parsed, so the check stays cheap for big files.

    Args:
        path: Filesystem path of the candidate CSV. ``None`` or ``""`` means
            nothing was selected and yields ``(False, "")``.
        max_bytes: Maximum allowed file size in bytes. Defaults to the
            module-level ``MAX_BYTES`` (read at call time).

    Returns:
        ``(True, "")`` when the file passes; otherwise ``(False, reason)``
        where *reason* is a user-facing message (empty if no path was given).
        Problems are reported through the return value, not raised.
    """
    if not path:
        return False, ""
    if not os.path.isfile(path):
        return False, "Not a file"
    limit = MAX_BYTES if max_bytes is None else max_bytes
    try:
        size = os.path.getsize(path)
    except OSError as exc:
        print("validate_csv_path size:", exc)
        return False, "Could not read file"
    if size > limit:
        # Derive the figure from the enforced limit so the message can never
        # disagree with the actual cap.
        limit_mb = limit / (1024 * 1024)
        message = f"File exceeds {limit_mb:g} MB (only the text column is used)."
        return False, message
    try:
        # nrows=0 parses only the header row.
        df = pd.read_csv(path, nrows=0)
    except Exception as exc:
        # pandas can raise several unrelated types here (empty file, parse
        # errors, decode errors); all map to one user-facing message.
        print("validate_csv_path read_csv:", exc)
        return False, "Could not parse CSV header"
    if "text" not in df.columns:
        return False, 'CSV must include a column named exactly "text".'
    return True, ""


def export_path_for_source(source_path: str) -> str:
    """Return where the results CSV for *source_path* should be written.

    A new ``bill_csv_*`` directory is created with ``tempfile.mkdtemp`` on
    every call; the returned file path inside it is not created, and the
    directory is not cleaned up here. The filename is the source's stem
    followed by ``_v3_teacher_results.csv``.
    """
    base_name = Path(source_path).stem
    export_dir = Path(tempfile.mkdtemp(prefix="bill_csv_"))
    return str(export_dir.joinpath(base_name + "_v3_teacher_results.csv"))


def run_csv_inference(
    csv_path: str,
    progress_cb: Optional[Callable[[int, int], None]] = None,
) -> str:
    """Classify each row's ``text`` and write an augmented copy of the CSV.

    Appends ``v3_teacher_results``, ``v3_teacher_label`` and
    ``v3_teacher_prob`` columns, taken from the three elements of each row
    returned by ``classify_batch``. The result is written without the index
    to a path from ``export_path_for_source``.

    Args:
        csv_path: Path to a CSV containing a ``text`` column.
        progress_cb: Passed through unchanged to ``classify_batch``;
            presumably called with (done, total) counts — TODO confirm
            against ``inference.classify_batch``.

    Returns:
        Path of the written results CSV.

    Raises:
        KeyError: if the CSV has no ``text`` column.
        ValueError: if ``classify_batch`` returns a different number of rows
            than there are texts, or rows that are not 3-element sequences.
    """
    # Keep ``text`` as strings so numeric-looking cells ("007", "3.10") are
    # classified and written back verbatim instead of being re-rendered as
    # numbers. Missing cells still arrive as NaN and are sent as "".
    df = pd.read_csv(csv_path, dtype={"text": str})
    texts = ["" if pd.isna(x) else str(x) for x in df["text"].tolist()]
    # Materialize so an iterator result can be length-checked and unzipped.
    rows = list(classify_batch(texts, batch_size=32, progress_cb=progress_cb))
    if len(rows) != len(texts):
        raise ValueError(
            f"classify_batch returned {len(rows)} rows for {len(texts)} texts"
        )
    if rows:
        results, labels, probs = zip(*rows)
    else:
        # zip(*[]) produces nothing to unpack; fall back to empty columns.
        results, labels, probs = (), (), ()
    df["v3_teacher_results"] = list(results)
    df["v3_teacher_label"] = list(labels)
    df["v3_teacher_prob"] = list(probs)
    out_path = export_path_for_source(csv_path)
    df.to_csv(out_path, index=False)
    return out_path