Leore42 committed on
Commit
db7bc8a
·
verified ·
1 Parent(s): d129d09

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +90 -3
README.md CHANGED
@@ -1,3 +1,90 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - Leore42/RETAN
5
+ language:
6
+ - en
7
+ ---
8
+
9
+
10
+ This is an extremely tiny model that summarizes text into 4 words.
11
+
12
+ You will need to download the config, tokenizer, and model, and use this Python code as a starting point:
13
+ ```python
14
+ import torch
15
+ import tkinter as tk
16
+ import json
17
+ import torch.nn as nn
18
+ import math
19
+
20
+ class ThemeExtractor(nn.Module):
21
+ def __init__(self, vocab_size, d_model=64, nhead=4, num_layers=1, dropout=0.1):
22
+ super().__init__()
23
+ self.d_model = d_model
24
+ self.embedding = nn.Embedding(vocab_size, d_model)
25
+ encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=128, dropout=dropout)
26
+ self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
27
+ self.fc = nn.Linear(d_model, 1)
28
+ self.dropout = nn.Dropout(dropout)
29
+
30
+ def forward(self, x):
31
+ emb = self.embedding(x) * math.sqrt(self.d_model)
32
+ emb = emb.transpose(0, 1)
33
+ enc_out = self.encoder(emb)
34
+ enc_out = enc_out.transpose(0, 1)
35
+ enc_out = self.dropout(enc_out)
36
+ logits = self.fc(enc_out).squeeze(-1)
37
+ return logits
38
+
39
+ def load_model_and_tokenizer():
40
+ with open('config.json', 'r') as f:
41
+ config = json.load(f)
42
+ with open('tokenizer.json', 'r') as f:
43
+ vocab = json.load(f)
44
+ vocab_size = config["vocab_size"]
45
+ model = ThemeExtractor(vocab_size, d_model=64, nhead=4, num_layers=1, dropout=0.2)
46
+ model.load_state_dict(torch.load("theme_extractor.pth"))
47
+ return model, vocab, config
48
+
49
+ def generate_text(model, vocab, config, input_text):
50
+ inv_vocab = {v: k for k, v in vocab.items()}
51
+ max_len = config["max_len"]
52
+ tokens = input_text.lower().split()
53
+ token_ids = [vocab.get(token, vocab["<unk>"]) for token in tokens]
54
+ if len(token_ids) < max_len:
55
+ token_ids += [vocab["<pad>"]] * (max_len - len(token_ids))
56
+ else:
57
+ token_ids = token_ids[:max_len]
58
+ input_tensor = torch.tensor([token_ids], dtype=torch.long).to(device)
59
+ model.eval()
60
+ with torch.no_grad():
61
+ logits = model(input_tensor)
62
+ probs = torch.sigmoid(logits).squeeze(0)
63
+ topk = torch.topk(probs, 4)
64
+ indices = topk.indices.cpu().numpy()
65
+ selected = sorted(indices, key=lambda i: i)
66
+ theme_words = [tokens[i] for i in selected if i < len(tokens)]
67
+ return ' '.join(theme_words)
68
+
69
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
+ model, vocab, config = load_model_and_tokenizer()
71
+
72
+ def on_generate():
73
+ input_text = entry_input.get()
74
+ generated = generate_text(model, vocab, config, input_text)
75
+ label_output.config(text="Generated Themes: " + generated)
76
+
77
+ root = tk.Tk()
78
+ root.title("Theme Extractor")
79
+
80
+ entry_input = tk.Entry(root, width=50)
81
+ entry_input.pack(pady=10)
82
+
83
+ button_generate = tk.Button(root, text="Generate Themes", command=on_generate)
84
+ button_generate.pack(pady=10)
85
+
86
+ label_output = tk.Label(root, text="Generated Themes: ", wraplength=400)
87
+ label_output.pack(pady=10)
88
+
89
+ root.mainloop()
90
+ ```