memerchik committed on
Commit
2d984ae
Β·
verified Β·
1 Parent(s): dc24b09

Upload 2 files

Browse files
Files changed (2) hide show
  1. test_cuda.py +6 -0
  2. testing.py +127 -0
test_cuda.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import torch, platform, sys
2
+ print("Torch:", torch.__version__)
3
+ print("Built for CUDA:", torch.version.cuda)
4
+ print("CUDA available:", torch.cuda.is_available())
5
+ if torch.cuda.is_available():
6
+ print("GPU:", torch.cuda.get_device_name(0))
testing.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import csv
2
+ # import random
3
+ import pandas as pd
4
+ import torch
5
+ from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
6
+ from typing import Literal
7
+ import os
8
+ from datetime import datetime
9
+
10
+
11
+ # ──────────────────────────────────────────────────────────────────────────────
12
+ # Load the fine-tuned model and run inference on each prompt
13
+ # ──────────────────────────────────────────────────────────────────────────────
14
+
15
def load_model_and_tokenizer(model_dir="./past_ref_classifier/updated_model"):
    """
    Load the fine-tuned DistilBERT classifier and its tokenizer.

    The model is switched to eval mode before being returned so it is
    ready for inference. Adjust `model_dir` if the checkpoint lives
    elsewhere.
    """
    fast_tokenizer = DistilBertTokenizerFast.from_pretrained(model_dir)
    classifier = DistilBertForSequenceClassification.from_pretrained(model_dir)
    classifier.eval()
    return fast_tokenizer, classifier
23
+
24
@torch.no_grad()
def classify_prompts(df, tokenizer, model, max_length=128, device="cuda" if torch.cuda.is_available() else "cpu"):
    """
    Run the binary classifier over every row of df["text"].

    Adds two columns to `df` (which is also returned):
      - pred_label: 0 or 1 (1 when the probability of label 1 is >= 0.5)
      - prob_past:  probability of label=1
    """
    model.to(device)
    labels = []
    scores = []
    total = len(df)

    for idx, prompt in enumerate(df["text"]):
        encoded = tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        # Single-example batch: squeeze the (1, 2) logits down to shape (2,).
        logits = model(**encoded).logits.squeeze()
        p_past = torch.softmax(logits, dim=-1)[1].item()

        scores.append(p_past)
        labels.append(int(p_past >= 0.5))

        # Lightweight progress reporting every 50 prompts.
        if (idx + 1) % 50 == 0:
            print(f"Classified {idx + 1}/{total} prompts")

    df["pred_label"] = labels
    df["prob_past"] = scores
    return df
58
+
59
+
60
def read_txt_as_dataframe(txt_input):
    """
    Build a one-column ('text') DataFrame from a text file path or a raw string.

    If `txt_input` names an existing file it is read (UTF-8); otherwise the
    argument itself is treated as the content. Blank lines are dropped, and
    stray "[" / "]" wrapper lines (JSON-list leftovers) are stripped.
    """
    if os.path.isfile(txt_input):
        with open(txt_input, 'r', encoding='utf-8') as f:
            raw = f.read()
    else:
        # Assume txt_input itself is the text content.
        raw = txt_input

    # Split into lines, strip whitespace, remove blanks.
    lines = [line.strip() for line in raw.splitlines() if line.strip()]

    # Drop a leading "[" wrapper line if present.
    # (Fixed comment: the original said "2nd line" but the code checks lines[0].)
    if len(lines) > 1 and lines[0] == "[":
        lines.pop(0)

    # Drop a trailing "]" wrapper line if present.
    if lines and lines[-1] == "]":
        lines.pop(-1)

    # Build the DataFrame; an empty input yields an empty 'text' column.
    df = pd.DataFrame(lines, columns=['text'])
    return df
83
+
84
# Accepted input modes for run_tagging.
AllowedMode = Literal['txt_file_path', 'txt_file', 'csv_file_path', "csv_file"]
# NOTE: Literal[True, False] is equivalent to bool; kept as a named alias
# so existing annotations keep working.
AllowedOut = Literal[True, False]

def run_tagging(mode: AllowedMode, data_or_path="", out_dir=".", prefix="data", out_as_a_df_variable: AllowedOut = False):
    """
    Tag prompts with the past-reference classifier and save a timestamped CSV.

    Parameters:
        mode: what `data_or_path` is — a CSV path/content or a TXT path/content.
        data_or_path: file path or raw text, depending on `mode`.
        out_dir: directory to write the results CSV into.
        prefix: filename prefix for the results CSV.
        out_as_a_df_variable: when True, also return the results DataFrame.

    Returns 0 for an unrecognized mode (sentinel kept for backward
    compatibility), the results DataFrame when requested, else None.
    """
    if mode in ("csv_file", "csv_file_path"):
        df = pd.read_csv(data_or_path)
    elif mode in ("txt_file_path", "txt_file"):
        df = read_txt_as_dataframe(data_or_path)
    else:
        # Unknown mode: preserve the original sentinel return value.
        return 0

    # Load model + tokenizer.
    tokenizer, model = load_model_and_tokenizer(
        model_dir="./past_ref_classifier/updated_model_3"
    )

    # Classify each prompt.
    df_results = classify_prompts(df, tokenizer, model)

    # Print first 20 results to console, and save full CSV.
    print("\nFirst 20 inference results:\n")
    print(df_results.head(20).to_string(index=False))

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{prefix}_{ts}.csv"
    # BUG FIX: the output path previously contained a literal placeholder
    # instead of the computed `filename`, which was never used.
    full_path = os.path.join(out_dir, filename)

    df_results.to_csv(full_path, index=False)
    print(f"\nSaved full results (with pred_label and prob_past) to {full_path}")

    if out_as_a_df_variable:
        return df_results
117
+
118
+
119
+ if __name__ == "__main__":
120
+ runMode = int(input("Please select a running mode:\n\n1. Txt file path\n2. Csv file path\n\n"))
121
+ if runMode>0 and runMode<5:
122
+ if runMode==1:
123
+ path_to_txt=input("Please provide path to the txt file\n")
124
+ run_tagging(mode="txt_file_path", data_or_path=path_to_txt)
125
+ elif runMode==2:
126
+ path_to_csv=input("Please provide path to the csv file\n")
127
+ run_tagging(mode="csv_file_path", data_or_path=path_to_csv)