Vincent Qin commited on
Commit
4c41360
·
1 Parent(s): e5c1f52

Added DistilBERT test

Browse files
Files changed (1) hide show
  1. app.py +98 -2
app.py CHANGED
@@ -1,4 +1,100 @@
1
  import streamlit as st
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import numpy as np
3
+ from datasets import load_dataset, Dataset
4
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, Trainer, TrainingArguments
5
+ from datasets import load_metric
6
+ import torch
7
 
8
+ # Load datasets (IMDB and SST2) and combine them
9
+ @st.cache_resource
10
+ def load_datasets():
11
+ imdb = load_dataset('imdb', split='train[:5000]')
12
+ sst2 = load_dataset('glue', 'sst2', split='train[:5000]')
13
+
14
+ # Combine datasets into a single list
15
+ train_list = [{'text': example['text'], 'label': example['label']} for example in imdb] + [{'text': example['sentence'], 'label': example['label']} for example in sst2]
16
+ full_data = Dataset.from_list(train_list)
17
+
18
+ # Split the dataset into train/validation/test
19
+ train_data = full_data.train_test_split(test_size=0.2, seed=42)
20
+ train_data = train_data['train'].train_test_split(test_size=0.25, seed=42) # 60% train, 20% validation, 20% test
21
+ return train_data['train'], train_data['test']
22
+
23
+ train_dataset, val_dataset = load_datasets()
24
+
25
+ # Load the tokenizer and model
26
+ @st.cache_resource
27
+ def load_tokenizer_model():
28
+ tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
29
+ model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
30
+ return tokenizer, model
31
+
32
+ tokenizer, model = load_tokenizer_model()
33
+
34
+ # Preprocess function for tokenization
35
+ def preprocess_function(examples):
36
+ return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)
37
+
38
+ # Tokenize datasets
39
+ tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True)
40
+ tokenized_val_dataset = val_dataset.map(preprocess_function, batched=True)
41
+
42
+ # Define the training arguments
43
+ training_args = TrainingArguments(
44
+ output_dir='./results',
45
+ evaluation_strategy='epoch',
46
+ learning_rate=2e-5,
47
+ per_device_train_batch_size=16,
48
+ per_device_eval_batch_size=16,
49
+ num_train_epochs=3,
50
+ weight_decay=0.01,
51
+ logging_dir='./logs',
52
+ )
53
+
54
+ # Load accuracy metric
55
+ metric = load_metric('accuracy')
56
+
57
+ # Function to compute metrics
58
+ def compute_metrics(eval_pred):
59
+ logits, labels = eval_pred
60
+ predictions = np.argmax(logits, axis=-1)
61
+ return metric.compute(predictions=predictions, references=labels)
62
+
63
+ # Initialize the trainer
64
+ trainer = Trainer(
65
+ model=model,
66
+ args=training_args,
67
+ train_dataset=tokenized_train_dataset,
68
+ eval_dataset=tokenized_val_dataset,
69
+ compute_metrics=compute_metrics,
70
+ )
71
+
72
+ # Streamlit UI
73
+ st.title("DistilBERT Sentiment Training and Inference")
74
+
75
+ # Button to start training
76
+ if st.button("Train the Model"):
77
+ st.write("Training the model... This will take some time.")
78
+ trainer.train()
79
+ st.write("Model training complete!")
80
+
81
+ # User input for inference
82
+ st.write("Once the model is trained, you can enter a sentence for sentiment analysis:")
83
+ user_input = st.text_area("Enter a sentence:")
84
+
85
+ # Function to make predictions
86
+ def predict_sentiment(text):
87
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
88
+ with torch.no_grad():
89
+ outputs = model(**inputs)
90
+ logits = outputs.logits
91
+ prediction = torch.argmax(logits, dim=-1).item()
92
+ return "Positive" if prediction == 1 else "Negative"
93
+
94
+ # Button to generate predictions after training
95
+ if st.button("Analyze Sentiment"):
96
+ if user_input.strip():
97
+ result = predict_sentiment(user_input)
98
+ st.write(f"Predicted Sentiment: **{result}**")
99
+ else:
100
+ st.write("Please enter a sentence.")