bashyaldhiraj2067 committed on
Commit
f97a63c
·
verified ·
1 Parent(s): aa65d0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +305 -141
app.py CHANGED
@@ -1,199 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import gradio as gr
5
- from torch.utils.data import Dataset
6
- from transformers import PreTrainedModel, PretrainedConfig, Trainer, TrainingArguments
7
- from datasets import load_dataset
8
- import numpy as np
9
-
10
- # =====================
11
- # 1. Load Dataset Subsets
12
- # =====================
13
- dataset = load_dataset("bashyaldhiraj2067/500k_copy_error_dataset")
14
- train_subset = dataset["train"].select(range(int(len(dataset["train"]) * 0.1)))
15
- test_subset = dataset["test"].select(range(int(len(dataset["test"]) * 0.1)))
16
- print(f"Subset train size: {len(train_subset)}")
17
- print(f"Subset test size: {len(test_subset)}")
18
-
19
- # =====================
20
- # 2. Tokenizer
21
- # =====================
22
  special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]
23
- nepali_chars = list("अआइईउऊऋॠऌॡऎएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसह्ािीुूृॄेैोौंंःँ।०१२३४५६७८९,.;?!़ॅंःॊॅऒऽॉड़ॐ॥ऑऱफ़ढ़")
 
 
 
24
  char_vocab = special_tokens + nepali_chars
25
- char2id = {char: idx for idx, char in enumerate(char_vocab)}
26
- id2char = {idx: char for char, idx in char2id.items()}
27
- vocab_size = len(char2id)
28
 
29
  class CharTokenizer:
30
- def __init__(self, char2id, id2char, vocab_size):
31
- self.char2id = char2id
32
- self.id2char = id2char
33
  self.pad_token_id = char2id["<pad>"]
34
  self.unk_token_id = char2id["<unk>"]
35
  self.bos_token_id = char2id["<s>"]
36
  self.eos_token_id = char2id["</s>"]
37
- self.vocab_size = vocab_size
38
 
39
  def encode(self, text, max_length=128):
40
- ids = [self.char2id.get(ch, self.unk_token_id) for ch in text]
41
  ids = ids[:max_length]
42
  return ids + [self.pad_token_id] * (max_length - len(ids))
43
 
44
  def decode(self, ids):
45
- return ''.join([self.id2char.get(i, '') for i in ids if i != self.pad_token_id])
46
-
47
- def __call__(self, text, text_target=None, max_length=128):
48
- input_ids = self.encode(text, max_length)
49
- input_ids = torch.clamp(torch.tensor(input_ids), max=self.vocab_size - 1).tolist()
50
- result = {"input_ids": input_ids, "attention_mask": [1 if i != self.pad_token_id else 0 for i in input_ids]}
51
- if text_target:
52
- labels = self.encode(text_target, max_length)
53
- result["labels"] = labels
54
- return result
55
-
56
- tokenizer = CharTokenizer(char2id, id2char, vocab_size=vocab_size)
57
-
58
- # =====================
59
- # 3. Dataset
60
- # =====================
61
- class CopyDataset(Dataset):
62
- def __init__(self, data, tokenizer, max_length=128):
63
- self.data = data
64
- self.tokenizer = tokenizer
65
- self.max_length = max_length
66
-
67
- def __len__(self):
68
- return len(self.data)
69
-
70
- def __getitem__(self, idx):
71
- noisy = self.data[idx]['incorrect']
72
- clean = self.data[idx]['correct']
73
- return self.tokenizer(noisy, text_target=clean, max_length=self.max_length)
74
-
75
- train_dataset = CopyDataset(train_subset, tokenizer)
76
- eval_dataset = CopyDataset(test_subset, tokenizer)
77
-
78
- # =====================
79
- # 4. Transformer with Copy Mechanism
80
- # =====================
81
  class TransformerCopyConfig(PretrainedConfig):
82
- def __init__(self, vocab_size=len(char2id), **kwargs):
 
83
  super().__init__(**kwargs)
84
  self.vocab_size = vocab_size
85
 
86
- # --- Model Components ---
87
  class PositionalEncoding(nn.Module):
88
  def __init__(self, d_model, max_len=512):
89
  super().__init__()
90
  pe = torch.zeros(max_len, d_model)
91
  position = torch.arange(0, max_len).unsqueeze(1)
92
- div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))
93
- pe[:, 0::2] = torch.sin(position * div_term)
94
- pe[:, 1::2] = torch.cos(position * div_term)
95
- self.register_buffer('pe', pe.unsqueeze(0))
 
 
96
 
97
  def forward(self, x):
98
- return x + self.pe[:, :x.size(1)]
99
 
100
  class TransformerCopyModel(nn.Module):
101
- def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, dim_ff=512, dropout=0.1):
102
  super().__init__()
103
  self.embedding = nn.Embedding(vocab_size, d_model)
104
- self.positional_encoding = PositionalEncoding(d_model)
105
-
106
- encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_ff, dropout)
107
- decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_ff, dropout)
108
-
109
- self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
110
- self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
111
-
112
- self.copy_attention = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
113
- self.copy_gate = nn.Linear(d_model * 2, 1)
114
 
115
- self.output_layer = nn.Linear(d_model, vocab_size)
 
116
 
117
- def forward(self, input_ids, attention_mask=None, labels=None):
118
- src = input_ids
119
- tgt = labels[:, :-1]
120
- tgt_y = labels[:, 1:]
121
 
122
- src_embed = self.embedding(src)
123
- tgt_embed = self.embedding(tgt)
124
- src_embed = self.positional_encoding(src_embed)
125
- tgt_embed = self.positional_encoding(tgt_embed)
126
 
127
- src_mask = (src == tokenizer.pad_token_id)
128
- tgt_mask = (tgt == tokenizer.pad_token_id)
 
129
 
130
- memory = self.encoder(src_embed.transpose(0, 1), src_key_padding_mask=src_mask)
131
- output = self.decoder(
132
- tgt_embed.transpose(0, 1),
133
- memory,
134
- tgt_key_padding_mask=tgt_mask,
135
- memory_key_padding_mask=src_mask
136
  )
137
 
138
- attn_output, attn_weights = self.copy_attention(output, memory, memory, key_padding_mask=src_mask)
139
- concat = torch.cat([output, attn_output], dim=-1)
140
- copy_prob = torch.sigmoid(self.copy_gate(concat))
141
 
142
- gen_logits = self.output_layer(output)
143
- gen_probs = F.softmax(gen_logits, dim=-1)
144
-
145
- loss = F.cross_entropy(
146
- gen_logits.transpose(0, 1).reshape(-1, gen_logits.size(-1)),
147
- tgt_y.reshape(-1),
148
- ignore_index=tokenizer.pad_token_id
149
- ) if labels is not None else None
150
-
151
- return {"loss": loss, "logits": gen_logits.transpose(0, 1)}
152
-
153
- # --- HF Wrapper ---
154
  class TransformerCopyHF(PreTrainedModel):
155
  config_class = TransformerCopyConfig
 
156
  def __init__(self, config):
157
  super().__init__(config)
158
  self.model = TransformerCopyModel(config.vocab_size)
159
 
160
- def forward(self, input_ids, attention_mask=None, labels=None):
161
- return self.model(input_ids, attention_mask, labels)
162
 
163
- model = TransformerCopyHF.from_pretrained("bashyaldhiraj2067/remove_copy_transformer")
 
 
 
 
 
 
 
164
  model.eval()
165
 
166
- # =====================
167
- # 5. Inference Function
168
- # =====================
169
- def generate_clean_text(input_text, max_length=128):
170
- model_input = tokenizer.encode(input_text, max_length=max_length)
171
- input_ids = torch.tensor([model_input])
172
- # Create dummy target input (just start token)
173
- decoder_input = torch.tensor([[tokenizer.bos_token_id]])
 
 
 
 
 
 
 
174
  output_tokens = []
175
- for _ in range(max_length):
176
- with torch.no_grad():
177
- out = model(input_ids=input_ids, labels=torch.cat([decoder_input, torch.zeros((1, 1), dtype=torch.long)], dim=1))
178
- next_token_logits = out["logits"][:, -1, :]
179
- next_token = torch.argmax(next_token_logits, dim=-1)
180
 
181
- next_token_id = next_token.item()
182
-
183
- if next_token_id == tokenizer.pad_token_id:
 
 
 
184
  break
185
- output_tokens.append(next_token_id)
186
- decoder_input = torch.cat([decoder_input, next_token.unsqueeze(0)], dim=1)
 
187
 
188
  return tokenizer.decode(output_tokens)
189
 
 
 
 
 
 
190
 
191
- # Gradio Interface Setup
192
- iface = gr.Interface(
193
- fn=generate_clean_text,
194
- inputs=gr.Textbox(label="Noisy Text"),
195
- outputs=gr.Textbox(label="Cleaned Text"),
196
- live=True
197
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
- iface.launch(debug=True)
 
1
+ # import torch
2
+ # import torch.nn as nn
3
+ # import torch.nn.functional as F
4
+ # import gradio as gr
5
+ # from torch.utils.data import Dataset
6
+ # from transformers import PreTrainedModel, PretrainedConfig, Trainer, TrainingArguments
7
+ # from datasets import load_dataset
8
+ # import numpy as np
9
+
10
+ # # =====================
11
+ # # 1. Load Dataset Subsets
12
+ # # =====================
13
+ # dataset = load_dataset("bashyaldhiraj2067/500k_copy_error_dataset")
14
+ # train_subset = dataset["train"].select(range(int(len(dataset["train"]) * 0.1)))
15
+ # test_subset = dataset["test"].select(range(int(len(dataset["test"]) * 0.1)))
16
+ # print(f"Subset train size: {len(train_subset)}")
17
+ # print(f"Subset test size: {len(test_subset)}")
18
+
19
+ # # =====================
20
+ # # 2. Tokenizer
21
+ # # =====================
22
+ # special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]
23
+ # nepali_chars = list("अआइईउऊऋॠऌॡऎएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसह्ािीुूृॄेैोौंंःँ।०१२३४५६७८९,.;?!़ॅंःॊॅऒऽॉड़ॐ॥ऑऱफ़ढ़")
24
+ # char_vocab = special_tokens + nepali_chars
25
+ # char2id = {char: idx for idx, char in enumerate(char_vocab)}
26
+ # id2char = {idx: char for char, idx in char2id.items()}
27
+ # vocab_size = len(char2id)
28
+
29
+ # class CharTokenizer:
30
+ # def __init__(self, char2id, id2char, vocab_size):
31
+ # self.char2id = char2id
32
+ # self.id2char = id2char
33
+ # self.pad_token_id = char2id["<pad>"]
34
+ # self.unk_token_id = char2id["<unk>"]
35
+ # self.bos_token_id = char2id["<s>"]
36
+ # self.eos_token_id = char2id["</s>"]
37
+ # self.vocab_size = vocab_size
38
+
39
+ # def encode(self, text, max_length=128):
40
+ # ids = [self.char2id.get(ch, self.unk_token_id) for ch in text]
41
+ # ids = ids[:max_length]
42
+ # return ids + [self.pad_token_id] * (max_length - len(ids))
43
+
44
+ # def decode(self, ids):
45
+ # return ''.join([self.id2char.get(i, '') for i in ids if i != self.pad_token_id])
46
+
47
+ # def __call__(self, text, text_target=None, max_length=128):
48
+ # input_ids = self.encode(text, max_length)
49
+ # input_ids = torch.clamp(torch.tensor(input_ids), max=self.vocab_size - 1).tolist()
50
+ # result = {"input_ids": input_ids, "attention_mask": [1 if i != self.pad_token_id else 0 for i in input_ids]}
51
+ # if text_target:
52
+ # labels = self.encode(text_target, max_length)
53
+ # result["labels"] = labels
54
+ # return result
55
+
56
+ # tokenizer = CharTokenizer(char2id, id2char, vocab_size=vocab_size)
57
+
58
+ # # =====================
59
+ # # 3. Dataset
60
+ # # =====================
61
+ # class CopyDataset(Dataset):
62
+ # def __init__(self, data, tokenizer, max_length=128):
63
+ # self.data = data
64
+ # self.tokenizer = tokenizer
65
+ # self.max_length = max_length
66
+
67
+ # def __len__(self):
68
+ # return len(self.data)
69
+
70
+ # def __getitem__(self, idx):
71
+ # noisy = self.data[idx]['incorrect']
72
+ # clean = self.data[idx]['correct']
73
+ # return self.tokenizer(noisy, text_target=clean, max_length=self.max_length)
74
+
75
+ # train_dataset = CopyDataset(train_subset, tokenizer)
76
+ # eval_dataset = CopyDataset(test_subset, tokenizer)
77
+
78
+ # # =====================
79
+ # # 4. Transformer with Copy Mechanism
80
+ # # =====================
81
+ # class TransformerCopyConfig(PretrainedConfig):
82
+ # def __init__(self, vocab_size=len(char2id), **kwargs):
83
+ # super().__init__(**kwargs)
84
+ # self.vocab_size = vocab_size
85
+
86
+ # # --- Model Components ---
87
+ # class PositionalEncoding(nn.Module):
88
+ # def __init__(self, d_model, max_len=512):
89
+ # super().__init__()
90
+ # pe = torch.zeros(max_len, d_model)
91
+ # position = torch.arange(0, max_len).unsqueeze(1)
92
+ # div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))
93
+ # pe[:, 0::2] = torch.sin(position * div_term)
94
+ # pe[:, 1::2] = torch.cos(position * div_term)
95
+ # self.register_buffer('pe', pe.unsqueeze(0))
96
+
97
+ # def forward(self, x):
98
+ # return x + self.pe[:, :x.size(1)]
99
+
100
+ # class TransformerCopyModel(nn.Module):
101
+ # def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, dim_ff=512, dropout=0.1):
102
+ # super().__init__()
103
+ # self.embedding = nn.Embedding(vocab_size, d_model)
104
+ # self.positional_encoding = PositionalEncoding(d_model)
105
+
106
+ # encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_ff, dropout)
107
+ # decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_ff, dropout)
108
+
109
+ # self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
110
+ # self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
111
+
112
+ # self.copy_attention = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
113
+ # self.copy_gate = nn.Linear(d_model * 2, 1)
114
+
115
+ # self.output_layer = nn.Linear(d_model, vocab_size)
116
+
117
+ # def forward(self, input_ids, attention_mask=None, labels=None):
118
+ # src = input_ids
119
+ # tgt = labels[:, :-1]
120
+ # tgt_y = labels[:, 1:]
121
+
122
+ # src_embed = self.embedding(src)
123
+ # tgt_embed = self.embedding(tgt)
124
+ # src_embed = self.positional_encoding(src_embed)
125
+ # tgt_embed = self.positional_encoding(tgt_embed)
126
+
127
+ # src_mask = (src == tokenizer.pad_token_id)
128
+ # tgt_mask = (tgt == tokenizer.pad_token_id)
129
+
130
+ # memory = self.encoder(src_embed.transpose(0, 1), src_key_padding_mask=src_mask)
131
+ # output = self.decoder(
132
+ # tgt_embed.transpose(0, 1),
133
+ # memory,
134
+ # tgt_key_padding_mask=tgt_mask,
135
+ # memory_key_padding_mask=src_mask
136
+ # )
137
+
138
+ # attn_output, attn_weights = self.copy_attention(output, memory, memory, key_padding_mask=src_mask)
139
+ # concat = torch.cat([output, attn_output], dim=-1)
140
+ # copy_prob = torch.sigmoid(self.copy_gate(concat))
141
+
142
+ # gen_logits = self.output_layer(output)
143
+ # gen_probs = F.softmax(gen_logits, dim=-1)
144
+
145
+ # loss = F.cross_entropy(
146
+ # gen_logits.transpose(0, 1).reshape(-1, gen_logits.size(-1)),
147
+ # tgt_y.reshape(-1),
148
+ # ignore_index=tokenizer.pad_token_id
149
+ # ) if labels is not None else None
150
+
151
+ # return {"loss": loss, "logits": gen_logits.transpose(0, 1)}
152
+
153
+ # # --- HF Wrapper ---
154
+ # class TransformerCopyHF(PreTrainedModel):
155
+ # config_class = TransformerCopyConfig
156
+ # def __init__(self, config):
157
+ # super().__init__(config)
158
+ # self.model = TransformerCopyModel(config.vocab_size)
159
+
160
+ # def forward(self, input_ids, attention_mask=None, labels=None):
161
+ # return self.model(input_ids, attention_mask, labels)
162
+
163
+ # model = TransformerCopyHF.from_pretrained("bashyaldhiraj2067/remove_copy_transformer")
164
+ # model.eval()
165
+
166
+ # # =====================
167
+ # # 5. Inference Function
168
+ # # =====================
169
+ # def generate_clean_text(input_text, max_length=128):
170
+ # model_input = tokenizer.encode(input_text, max_length=max_length)
171
+ # input_ids = torch.tensor([model_input])
172
+ # # Create dummy target input (just start token)
173
+ # decoder_input = torch.tensor([[tokenizer.bos_token_id]])
174
+ # output_tokens = []
175
+ # for _ in range(max_length):
176
+ # with torch.no_grad():
177
+ # out = model(input_ids=input_ids, labels=torch.cat([decoder_input, torch.zeros((1, 1), dtype=torch.long)], dim=1))
178
+ # next_token_logits = out["logits"][:, -1, :]
179
+ # next_token = torch.argmax(next_token_logits, dim=-1)
180
+
181
+ # next_token_id = next_token.item()
182
+
183
+ # if next_token_id == tokenizer.pad_token_id:
184
+ # break
185
+ # output_tokens.append(next_token_id)
186
+ # decoder_input = torch.cat([decoder_input, next_token.unsqueeze(0)], dim=1)
187
+
188
+ # return tokenizer.decode(output_tokens)
189
+
190
+
191
+ # # Gradio Interface Setup
192
+ # iface = gr.Interface(
193
+ # fn=generate_clean_text,
194
+ # inputs=gr.Textbox(label="Noisy Text"),
195
+ # outputs=gr.Textbox(label="Cleaned Text"),
196
+ # live=True
197
+ # )
198
+
199
+ # iface.launch(debug=True)
200
  import torch
201
  import torch.nn as nn
202
  import torch.nn.functional as F
203
  import gradio as gr
204
+ from transformers import PreTrainedModel, PretrainedConfig
205
+
206
+ # =========================================================
207
+ # 1. Tokenizer (CUSTOM – REQUIRED)
208
+ # =========================================================
 
 
 
 
 
 
 
 
 
 
 
 
209
  special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]
210
+ nepali_chars = list(
211
+ "अआइईउऊऋॠऌॡऎएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसह"
212
+ "ािीुूृॄेैोौंंःँ।०१२३४५६७८९,.;?!़ॅॊऒऽॉड़ॐ॥ऑऱफ़ढ़"
213
+ )
214
  char_vocab = special_tokens + nepali_chars
215
+ char2id = {c: i for i, c in enumerate(char_vocab)}
216
+ id2char = {i: c for c, i in char2id.items()}
 
217
 
218
  class CharTokenizer:
219
+ def __init__(self):
 
 
220
  self.pad_token_id = char2id["<pad>"]
221
  self.unk_token_id = char2id["<unk>"]
222
  self.bos_token_id = char2id["<s>"]
223
  self.eos_token_id = char2id["</s>"]
224
+ self.vocab_size = len(char2id)
225
 
226
  def encode(self, text, max_length=128):
227
+ ids = [char2id.get(ch, self.unk_token_id) for ch in text]
228
  ids = ids[:max_length]
229
  return ids + [self.pad_token_id] * (max_length - len(ids))
230
 
231
  def decode(self, ids):
232
+ return "".join(
233
+ id2char[i] for i in ids if i != self.pad_token_id
234
+ )
235
+
236
+ tokenizer = CharTokenizer()
237
+
238
+ # =========================================================
239
+ # 2. Model Definition (CUSTOM – REQUIRED)
240
+ # =========================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  class TransformerCopyConfig(PretrainedConfig):
242
+ model_type = "transformer_copy"
243
+ def __init__(self, vocab_size=tokenizer.vocab_size, **kwargs):
244
  super().__init__(**kwargs)
245
  self.vocab_size = vocab_size
246
 
 
247
  class PositionalEncoding(nn.Module):
248
  def __init__(self, d_model, max_len=512):
249
  super().__init__()
250
  pe = torch.zeros(max_len, d_model)
251
  position = torch.arange(0, max_len).unsqueeze(1)
252
+ div = torch.exp(
253
+ torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model)
254
+ )
255
+ pe[:, 0::2] = torch.sin(position * div)
256
+ pe[:, 1::2] = torch.cos(position * div)
257
+ self.register_buffer("pe", pe.unsqueeze(0))
258
 
259
  def forward(self, x):
260
+ return x + self.pe[:, : x.size(1)]
261
 
262
  class TransformerCopyModel(nn.Module):
263
+ def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4):
264
  super().__init__()
265
  self.embedding = nn.Embedding(vocab_size, d_model)
266
+ self.pos = PositionalEncoding(d_model)
 
 
 
 
 
 
 
 
 
267
 
268
+ enc_layer = nn.TransformerEncoderLayer(d_model, nhead)
269
+ dec_layer = nn.TransformerDecoderLayer(d_model, nhead)
270
 
271
+ self.encoder = nn.TransformerEncoder(enc_layer, num_layers)
272
+ self.decoder = nn.TransformerDecoder(dec_layer, num_layers)
 
 
273
 
274
+ self.fc = nn.Linear(d_model, vocab_size)
 
 
 
275
 
276
+ def forward(self, src, tgt):
277
+ src_emb = self.pos(self.embedding(src))
278
+ tgt_emb = self.pos(self.embedding(tgt))
279
 
280
+ memory = self.encoder(src_emb.transpose(0, 1))
281
+ out = self.decoder(
282
+ tgt_emb.transpose(0, 1), memory
 
 
 
283
  )
284
 
285
+ return self.fc(out.transpose(0, 1))
 
 
286
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  class TransformerCopyHF(PreTrainedModel):
288
  config_class = TransformerCopyConfig
289
+
290
  def __init__(self, config):
291
  super().__init__(config)
292
  self.model = TransformerCopyModel(config.vocab_size)
293
 
294
+ def forward(self, input_ids, decoder_input_ids):
295
+ return self.model(input_ids, decoder_input_ids)
296
 
297
+ # =========================================================
298
+ # 3. Load Weights from Hugging Face
299
+ # =========================================================
300
+ device = "cuda" if torch.cuda.is_available() else "cpu"
301
+
302
+ model = TransformerCopyHF.from_pretrained(
303
+ "bashyaldhiraj2067/remove_copy_transformer"
304
+ ).to(device)
305
  model.eval()
306
 
307
+ # =========================================================
308
+ # 4. Inference Function
309
+ # =========================================================
310
+ @torch.no_grad()
311
+ def generate_clean_text(text, max_len=128):
312
+ src = torch.tensor(
313
+ [tokenizer.encode(text, max_len)],
314
+ device=device
315
+ )
316
+
317
+ tgt = torch.tensor(
318
+ [[tokenizer.bos_token_id]],
319
+ device=device
320
+ )
321
+
322
  output_tokens = []
 
 
 
 
 
323
 
324
+ for _ in range(max_len):
325
+ logits = model(src, tgt)
326
+ next_token = torch.argmax(logits[:, -1], dim=-1)
327
+
328
+ token_id = next_token.item()
329
+ if token_id == tokenizer.pad_token_id:
330
  break
331
+
332
+ output_tokens.append(token_id)
333
+ tgt = torch.cat([tgt, next_token.unsqueeze(0)], dim=1)
334
 
335
  return tokenizer.decode(output_tokens)
336
 
337
+ # =========================================================
338
+ # 5. Gradio UI
339
+ # =========================================================
340
+ with gr.Blocks(title="Nepali GEC – Copy Transformer") as demo:
341
+ gr.Markdown("## 🇳🇵 Nepali Grammatical Error Correction")
342
 
343
+ inp = gr.Textbox(
344
+ label="Noisy / Incorrect Text",
345
+ lines=4,
346
+ placeholder="यहाँ गलत नेपाली वाक्य लेख्नुहोस्"
347
+ )
348
+
349
+ out = gr.Textbox(
350
+ label="Corrected Text",
351
+ lines=4
352
+ )
353
+
354
+ btn = gr.Button("Correct")
355
+
356
+ btn.click(
357
+ fn=generate_clean_text,
358
+ inputs=inp,
359
+ outputs=out
360
+ )
361
+
362
+ demo.launch()
363