# serviceadvisor/training/train_fusion.py
# Uploaded by viswanani in commit 1c7bc31 ("Upload 22 files", verified).
import argparse, os, pandas as pd, numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import joblib
def _text_features(text):
    """Return the hand-crafted feature vector for one lower-cased customer text.

    Order: [character length, "brake", "leak", "tire"/"tyre", "scratch"/"dent"].
    Keyword flags are plain substring tests, so e.g. "brakes" matches "brake",
    but "entire" also matches "tire" and "accident" also matches "dent".
    A predictor loading the saved classifier presumably has to build features
    in exactly this order — verify against the inference code.
    """
    return [
        len(text),
        int("brake" in text),
        int("leak" in text),
        int("tire" in text or "tyre" in text),
        int("scratch" in text or "dent" in text),
    ]


def main(args):
    """Train and save a logistic-regression issue classifier from annotations.

    Reads ``args.annotations`` as CSV. It needs an ``issue_label`` column and
    may have a ``customer_text`` column. The script builds five keyword
    features per row and holds out 20% of rows (``random_state=42``) to print
    macro F1. It saves ``{"clf": LogisticRegression, "labels": sorted labels}``
    to ``<args.out_dir>/best.joblib``. ``clf`` predicts integer indices into
    ``labels``.

    NOTE(review): ``args.vision_ckpt`` and ``args.nlp_ckpt`` are not read here.
    Despite the "fusion" naming, no vision/NLP model outputs are combined.

    Raises:
        KeyError: if the CSV has no ``issue_label`` column.
        ValueError: if no row has a non-missing ``issue_label``.
    """
    df = pd.read_csv(args.annotations)

    # Unlabeled rows cannot be trained on. A NaN label would also break
    # sorting against str labels and the dict lookup below (NaN != NaN).
    n_total = len(df)
    df = df.dropna(subset=["issue_label"])
    if df.empty:
        raise ValueError(f"no rows with a non-missing issue_label in {args.annotations}")
    if len(df) < n_total:
        print(f"dropped {n_total - len(df)} row(s) with missing issue_label")

    labels = sorted(df["issue_label"].unique().tolist())
    label_to_idx = {label: idx for idx, label in enumerate(labels)}

    # A missing column or missing cell counts as empty text. Without the
    # fillna, a missing cell would become the literal string "nan".
    if "customer_text" in df.columns:
        texts = df["customer_text"].fillna("").astype(str).str.lower()
    else:
        texts = pd.Series("", index=df.index)

    X = np.array([_text_features(text) for text in texts])
    y = np.array([label_to_idx[label] for label in df["issue_label"]])

    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
    # NOTE(review): len(text) is unscaled next to 0/1 flags. That may slow
    # solver convergence at max_iter=200; consider scaling if warnings appear.
    clf = LogisticRegression(max_iter=200).fit(X_train, y_train)
    y_pred = clf.predict(X_val)
    print("fusion macro F1:", f1_score(y_val, y_pred, average="macro"))

    os.makedirs(args.out_dir, exist_ok=True)
    joblib.dump({"clf": clf, "labels": labels}, os.path.join(args.out_dir, "best.joblib"))
    print("Saved", args.out_dir)
if __name__ == "__main__":
    # Command-line entry point: parse options and run training.
    ap = argparse.ArgumentParser(
        description="Train the issue-label classifier from annotations and save it to OUT_DIR/best.joblib."
    )
    ap.add_argument(
        "--annotations",
        required=True,
        help="CSV with an issue_label column and an optional customer_text column",
    )
    # main() never reads these two options; the help text says so, so users
    # do not assume their checkpoints are fused in.
    ap.add_argument(
        "--vision_ckpt",
        required=False,
        help="vision model checkpoint (currently ignored by this script)",
    )
    ap.add_argument(
        "--nlp_ckpt",
        required=False,
        help="NLP model checkpoint (currently ignored by this script)",
    )
    ap.add_argument(
        "--out_dir",
        default="checkpoints/fusion",
        help="directory that receives best.joblib (default: %(default)s)",
    )
    args = ap.parse_args()
    main(args)