TLH01 committed on
Commit
15bf0f2
·
verified ·
1 Parent(s): c65725b

create app_test.py

Browse files
Files changed (1) hide show
  1. app_test.py +52 -0
app_test.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Compare two Hugging Face sentiment classifiers (DistilBERT SST-2 and
# Twitter RoBERTa) against the labelled sentiment of a small sample of
# reviews, printing one row per review followed by each model's accuracy.
from transformers import pipeline

# Load pipelines
pipe_bert = pipeline(
    "text-classification",
    model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
)
pipe_roberta = pipeline(
    "text-classification",
    model="cardiffnlp/twitter-roberta-base-sentiment",
)

# Label mapping for RoBERTa: the model emits generic LABEL_<n> names, which
# are translated here into the same vocabulary used by the other columns.
roberta_label_mapping_dict = {
    'LABEL_2': 'Positive',
    'LABEL_1': 'Neutral',
    'LABEL_0': 'Negative',
}

# Sample input - the first `top_n` reviews.
# NOTE(review): `train_data` is not defined in this file; it is presumably a
# pandas DataFrame with 'review' and 'sentiment' columns provided by the
# surrounding session - confirm and load it explicitly if this is meant to
# run as a standalone script.
top_n = 10
reviews = train_data['review'][:top_n]
sentiments = train_data['sentiment'][:top_n]

# Keep (review, sentiment) pairs in a list rather than a dict keyed by the
# review text, so duplicate reviews are not silently collapsed into one entry.
data_to_test = list(zip(reviews, sentiments))

# Print header
print(f"{'Original':<10} | {'DistilBERT':<10} | {'RoBERTa':<10}")

# Track accuracy
bert_correct = 0
roberta_correct = 0
total = len(data_to_test)

for text, true_label in data_to_test:
    # Truncate for BOTH models: reviews longer than a model's maximum input
    # length would otherwise make the call fail instead of being classified.
    pred_bert = pipe_bert(text, truncation=True)[0]
    pred_roberta = pipe_roberta(text, truncation=True)[0]

    # Normalize labels to a common "Positive"/"Negative"/"Neutral" spelling
    original = true_label.strip().capitalize()
    bert = pred_bert["label"].capitalize()
    roberta = roberta_label_mapping_dict.get(pred_roberta["label"], "Unknown")

    # Accuracy check
    if bert == original:
        bert_correct += 1
    if roberta == original:
        roberta_correct += 1

    # Print results
    print(f"{original:<10} | {bert:<10} | {roberta:<10}")

# Calculate and print accuracy; report 0% for an empty sample instead of
# raising ZeroDivisionError.
bert_acc = (bert_correct / total) * 100 if total else 0.0
roberta_acc = (roberta_correct / total) * 100 if total else 0.0

print("\nAccuracy:")
print(f"DistilBERT: {bert_acc:.2f}%")
print(f"RoBERTa   : {roberta_acc:.2f}%")