Update app.py
Browse files
app.py
CHANGED
|
@@ -49,6 +49,14 @@ train_size = int(0.8 * len(dataset))
|
|
| 49 |
val_size = len(dataset) - train_size
|
| 50 |
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
# DataLoader Function (Fix Collate)
|
| 53 |
def collate_fn(batch):
|
| 54 |
keys = batch[0].keys()
|
|
|
|
| 49 |
val_size = len(dataset) - train_size
|
| 50 |
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
|
| 51 |
|
| 52 |
+
def get_top_words(corpus, n=None):
    """Return the most frequent non-stop-word terms in ``corpus``.

    Args:
        corpus: Iterable of text documents. It is materialised into a list
            up front, so single-use iterators and generators are supported
            (the vectorizer would otherwise need to walk the corpus twice).
        n: Maximum number of ``(word, count)`` pairs to return. ``None``
            (the default) returns every term.

    Returns:
        A list of ``(word, count)`` tuples sorted by descending count.
        An empty corpus yields an empty list.

    Raises:
        ValueError: If ``corpus`` is a single string rather than an iterable
            of documents. NOTE(review): scikit-learn is also expected to
            raise ValueError when every token is an English stop word
            (empty vocabulary) - confirm against the installed version.
    """
    # A bare string would be split into characters by list(); reject it with
    # the same exception type the vectorizer itself uses for this mistake.
    if isinstance(corpus, str):
        raise ValueError(
            "corpus must be an iterable of documents, not a single string"
        )

    documents = list(corpus)
    if not documents:
        # Nothing to count; avoid the vectorizer's empty-vocabulary error.
        return []

    vec = CountVectorizer(stop_words='english')
    # fit_transform learns the vocabulary and builds the document-term
    # matrix in one pass instead of tokenizing the corpus twice.
    bag_of_words = vec.fit_transform(documents)
    # Column sums give the total occurrences of each vocabulary term.
    sum_words = bag_of_words.sum(axis=0)
    words_freq = [
        (word, sum_words[0, idx]) for word, idx in vec.vocabulary_.items()
    ]
    words_freq.sort(key=lambda pair: pair[1], reverse=True)
    # Slicing with None keeps the whole list, so n=None returns every term.
    return words_freq[:n]
|
| 59 |
+
|
| 60 |
# DataLoader Function (Fix Collate)
|
| 61 |
def collate_fn(batch):
|
| 62 |
keys = batch[0].keys()
|