Spaces:
Running
Running
Update pages/17_RNN_News.py
Browse files- pages/17_RNN_News.py +13 -17
pages/17_RNN_News.py
CHANGED
|
@@ -51,25 +51,22 @@ def load_data():
|
|
| 51 |
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
|
| 52 |
vocab.set_default_index(vocab["<unk>"])
|
| 53 |
|
| 54 |
-
|
| 55 |
-
global text_pipeline, label_pipeline
|
| 56 |
-
text_pipeline = lambda x: vocab(tokenizer(x))
|
| 57 |
-
label_pipeline = lambda x: int(x) - 1
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
return vocab, train_loader, valid_loader, test_loader
|
| 73 |
|
| 74 |
# Function to train the network
|
| 75 |
def train_network(net, iterator, optimizer, criterion, epochs):
|
|
@@ -116,7 +113,6 @@ def evaluate_network(net, iterator, criterion):
|
|
| 116 |
|
| 117 |
# Load data
|
| 118 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 119 |
-
vocab, train_loader, valid_loader, test_loader = load_data()
|
| 120 |
|
| 121 |
# Streamlit interface
|
| 122 |
st.title("RNN for Text Classification on AG News Dataset")
|
|
|
|
| 51 |
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
|
| 52 |
vocab.set_default_index(vocab["<unk>"])
|
| 53 |
|
| 54 |
+
return vocab, tokenizer, list(train_iter), list(test_iter)
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
+
# Initialize global pipelines
|
| 57 |
+
vocab, tokenizer, train_dataset, test_dataset = load_data()
|
| 58 |
+
text_pipeline = lambda x: vocab(tokenizer(x))
|
| 59 |
+
label_pipeline = lambda x: int(x) - 1
|
| 60 |
|
| 61 |
+
# Create DataLoaders
|
| 62 |
+
train_size = int(0.8 * len(train_dataset))
|
| 63 |
+
valid_size = len(train_dataset) - train_size
|
| 64 |
+
train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])
|
| 65 |
|
| 66 |
+
BATCH_SIZE = 64
|
| 67 |
+
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
|
| 68 |
+
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
|
| 69 |
+
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
|
|
|
|
|
|
|
| 70 |
|
| 71 |
# Function to train the network
|
| 72 |
def train_network(net, iterator, optimizer, criterion, epochs):
|
|
|
|
| 113 |
|
| 114 |
# Load data
|
| 115 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
| 116 |
|
| 117 |
# Streamlit interface
|
| 118 |
st.title("RNN for Text Classification on AG News Dataset")
|