File size: 4,573 Bytes
2d984ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# import csv
# import random
import pandas as pd
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
from typing import Literal
import os
from datetime import datetime


# ──────────────────────────────────────────────────────────────────────────────
# Load the fine-tuned model and run inference on each prompt
# ──────────────────────────────────────────────────────────────────────────────

def load_model_and_tokenizer(model_dir="./past_ref_classifier/updated_model"):
    """
    Build the tokenizer / classifier pair stored under ``model_dir``.

    Both objects are created with ``from_pretrained(model_dir)``. The
    classifier is switched to evaluation mode before it is handed back.

    Returns:
        A ``(tokenizer, model)`` tuple, in that order.
    """
    fast_tokenizer = DistilBertTokenizerFast.from_pretrained(model_dir)
    # nn.Module.eval() returns the module itself, so the call can be chained.
    classifier = DistilBertForSequenceClassification.from_pretrained(model_dir).eval()
    return fast_tokenizer, classifier

@torch.no_grad()
def classify_prompts(df, tokenizer, model, max_length=128, device="cuda" if torch.cuda.is_available() else "cpu"):
    """
    Score every entry of ``df["text"]`` with the classifier, one prompt per
    forward pass.

    Two columns are written onto ``df`` itself (the input is modified in
    place) and that same DataFrame is returned:
    - pred_label: 1 when prob_past >= 0.5, otherwise 0
    - prob_past: softmax probability at index 1 of the model's logits

    A progress line is printed after every 50th prompt.
    """
    model.to(device)
    total = len(df)
    scores = []

    for count, prompt in enumerate(df["text"], start=1):
        encoded = tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        on_device = {name: tensor.to(device) for name, tensor in encoded.items()}

        # A single prompt is encoded per call, so squeeze() is used to drop
        # the leading batch dimension before indexing the class axis.
        logits = model(**on_device).logits.squeeze()
        scores.append(torch.softmax(logits, dim=-1)[1].item())

        if count % 50 == 0:
            print(f"Classified {count}/{total} prompts")

    df["pred_label"] = [int(score >= 0.5) for score in scores]
    df["prob_past"] = scores
    return df


def read_txt_as_dataframe(txt_input):
    """
    Turn line-oriented text into a one-column DataFrame named 'text'.

    ``txt_input`` is read as a UTF-8 file when it names an existing file;
    otherwise the argument itself is treated as the text content.

    Each line is whitespace-stripped and blank lines are discarded. A
    leading "[" line is dropped (only when more than one line remains), and
    a trailing "]" line is dropped.
    """
    if not os.path.isfile(txt_input):
        # Not an existing file: the argument is taken to be the text itself.
        content = txt_input
    else:
        with open(txt_input, 'r', encoding='utf-8') as handle:
            content = handle.read()

    stripped = (candidate.strip() for candidate in content.splitlines())
    entries = [candidate for candidate in stripped if candidate]

    # Discard a wrapping "[" on the first line ...
    if len(entries) > 1 and entries[0] == "[":
        entries = entries[1:]

    # ... and a wrapping "]" on the last line.
    if entries and entries[-1] == "]":
        entries = entries[:-1]

    return pd.DataFrame(entries, columns=['text'])

# Accepted values for run_tagging's ``mode`` argument.
AllowedMode = Literal['txt_file_path', 'txt_file', 'csv_file_path', "csv_file"]
# Accepted values for run_tagging's ``out_as_a_df_variable`` flag.
AllowedOut = Literal[True, False]


def run_tagging(
    mode: AllowedMode,
    data_or_path="",
    out_dir=".",
    prefix="data",
    out_as_a_df_variable: AllowedOut = False,
    model_dir="./past_ref_classifier/updated_model_3",
):
    """
    Load prompts, classify them, print a preview and save the results as CSV.

    Args:
        mode: How ``data_or_path`` is interpreted. The two csv modes pass it
            to ``pd.read_csv``; the two txt modes pass it to
            ``read_txt_as_dataframe``.
        data_or_path: A file path, or the data itself, depending on ``mode``.
        out_dir: Directory that receives the CSV. It is created if missing;
            an empty string means the current directory.
        prefix: File name prefix; the file is named ``<prefix>_<timestamp>.csv``.
        out_as_a_df_variable: When true, the result DataFrame is returned.
        model_dir: Directory the tokenizer and model are loaded from.

    Returns:
        ``0`` when ``mode`` is not one of the accepted values (nothing is
        loaded or written), the result DataFrame when
        ``out_as_a_df_variable`` is true, otherwise ``None``.

    Raises:
        FileNotFoundError: In ``txt_file_path`` mode when ``data_or_path`` is
            not an existing file.
    """
    if mode in ("csv_file", "csv_file_path"):
        df = pd.read_csv(data_or_path)
    elif mode in ("txt_file_path", "txt_file"):
        # read_txt_as_dataframe falls back to treating its argument as raw
        # text when it is not a file. For an explicit path mode that would
        # classify a mistyped path as if it were a prompt, so fail loudly
        # instead (matching what pd.read_csv does for a missing csv path).
        if mode == "txt_file_path" and not os.path.isfile(data_or_path):
            raise FileNotFoundError(f"No such text file: {data_or_path!r}")
        df = read_txt_as_dataframe(data_or_path)
    else:
        # Unknown mode: kept as a quiet 0 return for existing callers.
        return 0

    # Load model + tokenizer
    tokenizer, model = load_model_and_tokenizer(model_dir=model_dir)

    # Classify each prompt
    df_results = classify_prompts(df, tokenizer, model)

    # Print first 20 results to console, and save full CSV
    print("\nFirst 20 inference results:\n")
    print(df_results.head(20).to_string(index=False))

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{prefix}_{ts}.csv"

    # An empty out_dir used to produce "/<filename>" (filesystem root);
    # treat it as the current directory and make sure the target exists.
    target_dir = out_dir or "."
    os.makedirs(target_dir, exist_ok=True)
    full_path = os.path.join(target_dir, filename)

    df_results.to_csv(full_path, index=False)
    # Report the full path so the message is right when out_dir is not ".".
    print(f"\nSaved full results (with pred_label and prob_past) to {full_path}")

    if out_as_a_df_variable:
        return df_results
    return None


if __name__ == "__main__":
    # Interactive entry point: ask which input format to use, then ask for
    # the path and hand it to run_tagging with the default output settings.
    raw_choice = input("Please select a running mode:\n\n1. Txt file path\n2. Csv file path\n\n")
    try:
        runMode = int(raw_choice)
    except ValueError:
        # Non-numeric input used to surface as an uncaught traceback.
        raise SystemExit(f"Invalid selection {raw_choice!r}: please enter 1 or 2.") from None

    if runMode == 1:
        path_to_txt = input("Please provide path to the txt file\n")
        run_tagging(mode="txt_file_path", data_or_path=path_to_txt)
    elif runMode == 2:
        path_to_csv = input("Please provide path to the csv file\n")
        run_tagging(mode="csv_file_path", data_or_path=path_to_csv)
    else:
        # Only two modes are offered; any other number used to exit silently.
        raise SystemExit(f"Unsupported running mode {runMode}: please enter 1 or 2.")