soroush62 committed
Commit 563f4d5 · 1 parent: 2b13b28

Initial commit

Files changed (3)
  1. .gitignore +1 -0
  2. app.py +213 -0
  3. requirements.txt +8 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ venv
app.py ADDED
@@ -0,0 +1,213 @@
+ import random
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ from datasets import load_dataset
+ from sklearn.metrics import accuracy_score, f1_score
+ from transformers import (
+     AutoModelForSequenceClassification,
+     AutoTokenizer,
+     DataCollatorWithPadding,
+     EarlyStoppingCallback,
+     TextClassificationPipeline,
+     Trainer,
+     TrainingArguments,
+     pipeline,
+ )
+ from transformers_interpret import SequenceClassificationExplainer
+
+ # Set random seeds for reproducibility
+ SEED = 42
+ random.seed(SEED)
+ np.random.seed(SEED)
+ torch.manual_seed(SEED)
+ if torch.cuda.is_available():
+     torch.cuda.manual_seed_all(SEED)
+
+ # Pick the best available device: CUDA, then Apple MPS, then CPU
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+ elif torch.backends.mps.is_available():
+     device = torch.device("mps")
+ else:
+     device = torch.device("cpu")
+ print("Using device:", device)
+
+ # Load the AG News dataset
+ raw = load_dataset("SetFit/ag_news")
+ print(raw)
+
+ # Load the BERT tokenizer
+ MODEL_NAME = "bert-base-uncased"
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
+ # Tokenization function
+ def tokenize_fn(examples):
+     return tokenizer(examples["text"], truncation=True, max_length=128)
+
+ # Drop every column except the label during mapping (including the raw
+ # "text" column) so batching only ever sees tensor-friendly fields
+ cols_to_remove = [c for c in raw["train"].column_names if c not in ("label",)]
+
+ # Apply tokenization to the dataset
+ tokenized = raw.map(tokenize_fn, batched=True, remove_columns=cols_to_remove)
+
+ # Set dataset format to PyTorch tensors
+ tokenized.set_format("torch")
+
+ # Shuffle, then carve a 5,000-example validation set out of the training split
+ train_dataset = tokenized["train"].shuffle(seed=SEED)
+ val_split = train_dataset.train_test_split(test_size=5000, seed=SEED)
+ train_dataset = val_split["train"]
+ eval_dataset = val_split["test"]
+
+ print(train_dataset)
+
+ # Load pre-trained BERT with a fresh 4-way classification head
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=4)
+
+ # Data collator that dynamically pads each batch to its longest sequence
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
+ # Metrics computation using scikit-learn
+ def compute_metrics(eval_pred):
+     logits, labels = eval_pred
+     # Convert logits to predicted class indices
+     preds = np.argmax(logits, axis=-1)
+     # Accuracy and macro-averaged F1
+     acc = accuracy_score(labels, preds)
+     f1 = f1_score(labels, preds, average="macro")
+     return {"accuracy": acc, "f1_macro": f1}
+
+ # Define training arguments
+ training_args = TrainingArguments(
+     output_dir="./results",
+     eval_strategy="epoch",
+     save_strategy="epoch",
+     logging_strategy="epoch",
+     # report_to=[],  # uncomment to disable all integrations (no wandb, no tensorboard)
+     per_device_train_batch_size=8,
+     per_device_eval_batch_size=8,
+     num_train_epochs=3,
+     learning_rate=2e-5,
+     weight_decay=0.1,
+     warmup_steps=100,
+     load_best_model_at_end=True,
+     metric_for_best_model="eval_loss",
+     greater_is_better=False,
+     save_total_limit=3,
+     fp16=torch.cuda.is_available(),  # mixed precision only when CUDA is present
+     dataloader_drop_last=False,
+     gradient_accumulation_steps=1,
+     seed=SEED,
+ )
+
+ # Trainer with early stopping: halt if eval_loss fails to improve for 2 evaluations
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=train_dataset,
+     eval_dataset=eval_dataset,
+     tokenizer=tokenizer,
+     data_collator=data_collator,
+     compute_metrics=compute_metrics,
+     callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
+ )
+
+ # Start model training
+ trainer.train()
+
+ # Save the fine-tuned model and tokenizer
+ trainer.save_model("my-fine-tuned-bert")
+ tokenizer.save_pretrained("my-fine-tuned-bert")
+
+ # Reload the fine-tuned model and tokenizer from disk
+ new_model = AutoModelForSequenceClassification.from_pretrained("my-fine-tuned-bert")
+ new_tokenizer = AutoTokenizer.from_pretrained("my-fine-tuned-bert")
+
+ # Create a text classification pipeline over the reloaded weights
+ classifier = TextClassificationPipeline(
+     model=new_model,
+     tokenizer=new_tokenizer,
+ )
+
+ # Map class indices to AG News topic names
+ label_mapping = {
+     0: "World",
+     1: "Sports",
+     2: "Business",
+     3: "Sci/Tech",
+ }
+
+ # Smoke-test the classifier on a sample sentence
+ sample_text = "This movie was good"
+ result = classifier(sample_text)
+
+ # Map the predicted "LABEL_k" string back to a topic name
+ mapped_result = {
+     "label": label_mapping[int(result[0]["label"].split("_")[1])],
+     "score": result[0]["score"],
+ }
+ print(mapped_result)
+
+ # Reload once more for the demo and build an attribution explainer
+ MODEL_ID = "my-fine-tuned-bert"
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
+ explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)
+
+ label_names = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
+
+ # top_k=None makes the pipeline return scores for every class
+ device = 0 if torch.cuda.is_available() else -1
+ clf = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)
+
+ def predict(text: str):
+     """Return {topic name: probability}, sorted by descending score."""
+     text = (text or "").strip()
+     if not text:
+         return {}
+     out = clf(text, truncation=True)
+     # With top_k=None the pipeline nests results in an extra list
+     if isinstance(out, list) and isinstance(out[0], list):
+         out = out[0]
+     results = {}
+     for o in sorted(out, key=lambda x: -x["score"]):
+         idx = int(o["label"].split("_")[1])
+         results[label_names[idx]] = float(o["score"])
+     return results
+
+ # Build script-free HTML so the highlights render inside Gradio pages
+ def explain_html(text: str) -> str:
+     text = (text or "").strip()
+     if not text:
+         return "<i>Enter text to see highlighted words.</i>"
+     atts = explainer(text)  # list of (token, attribution score) pairs
+     toks = [t for t, _ in atts]
+     # Min-max normalise absolute attributions to [0, 1] for use as opacity
+     scores = np.abs([s for _, s in atts])
+     smin, smax = float(np.min(scores)), float(np.max(scores))
+     scores = (scores - smin) / (smax - smin + 1e-8)
+     spans = [
+         f"<span style='background: rgba(255,0,0,{0.15 + 0.85 * s:.2f});"
+         f"padding:2px 3px; margin:1px; border-radius:4px; display:inline-block'>{tok}</span>"
+         for tok, s in zip(toks, scores)
+     ]
+     return "<div style='line-height:2'>" + " ".join(spans) + "</div>"
+
+ def predict_and_explain(text: str):
+     return predict(text), explain_html(text)
+
+ demo = gr.Interface(
+     fn=predict_and_explain,
+     inputs=gr.Textbox(lines=3, label="Enter news headline"),
+     outputs=[
+         gr.Label(num_top_classes=4, label="Predicted topic"),
+         gr.HTML(label="Important-word highlights"),
+     ],
+     title="AG News Topic Classifier (BERT-base)",
+     description="Shows the predicted topic and highlights the words that influenced the decision.",
+ )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
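
Note: app.py converts the pipeline's generic "LABEL_0"…"LABEL_3" strings back into topic names by hand, in both mapped_result and predict(). A minimal alternative sketch, not part of this commit: attach the names to the model config at load time, so any pipeline built on the checkpoint returns readable labels directly.

# Hypothetical variant (not in this commit): store readable label names in
# the config so pipelines emit "Sports" rather than "LABEL_1".
id2label = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=4,
    id2label=id2label,
    label2id={v: k for k, v in id2label.items()},
)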
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ numpy
+ torch
+ datasets
+ transformers
+ accelerate
+ scikit-learn
+ transformers-interpret
+ gradio
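
None of these dependencies are pinned, so a fresh environment resolves to the latest releases (pip install -r requirements.txt, then python app.py; the venv directory ignored above is a natural home for the environment). One floor worth noting, assumed rather than tested against this repository: TrainingArguments(eval_strategy=...) as used in app.py requires transformers 4.41 or newer, since earlier releases only accept evaluation_strategy. A hypothetical fail-fast guard:

# Hypothetical guard (not in this commit): fail fast if transformers is too
# old for the eval_strategy keyword used in app.py. The 4.41 floor is an
# assumption, not something verified here.
import transformers
from packaging.version import Version

if Version(transformers.__version__) < Version("4.41.0"):
    raise RuntimeError(
        f"transformers {transformers.__version__} is too old; "
        "TrainingArguments(eval_strategy=...) needs >= 4.41"
    )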