flokymind committed
Commit 4215f17 ยท 1 Parent(s): 8a9a2f7

Upload model.py

Files changed (1): model.py +404 -0
model.py ADDED
@@ -0,0 +1,404 @@
"""
================================================
Arabic Diacritization - mishkala
Automatic Arabic diacritization model
https://huggingface.co/flokymind/mishkala
================================================
Requirements:
pip install torch pytorch-crf huggingface_hub
================================================
"""

# โ”€โ”€ Requirements โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
import subprocess, sys

def install(pkg):
    subprocess.check_call([sys.executable, "-m", "pip", "install", pkg, "-q"])

try:
    import torchcrf  # noqa: F401 -- only probing; the real import happens below
except ImportError:
    install("pytorch-crf")

try:
    from huggingface_hub import hf_hub_download
except ImportError:
    install("huggingface_hub")

# โ”€โ”€ Imports โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
import json, math, re
from pathlib import Path
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchcrf import CRF
from huggingface_hub import hf_hub_download


# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# 1. Constants
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
REPO_ID = "flokymind/mishkala"

DIACRITICS_SET = {
    '\u064e', '\u064b', '\u064f', '\u064c',
    '\u0650', '\u064d', '\u0651', '\u0652',
}

SPECIAL_TOKENS = {'PAD': 0, 'UNK': 1, 'BOS': 2, 'EOS': 3, 'MASK': 4, ' ': 5}

DIACRITIC_CLASSES = [
    'NO_DIACRITIC', 'FATHA', 'FATHATAN', 'DAMMA', 'DAMMATAN',
    'KASRA', 'KASRATAN', 'SUKUN', 'SHADDA',
    'SHADDA_FATHA', 'SHADDA_FATHATAN', 'SHADDA_DAMMA',
    'SHADDA_DAMMATAN', 'SHADDA_KASRA', 'SHADDA_KASRATAN',
]

DIACRITIC_MAP = {
    'NO_DIACRITIC': '',
    'FATHA': '\u064e',
    'FATHATAN': '\u064b',
    'DAMMA': '\u064f',
    'DAMMATAN': '\u064c',
    'KASRA': '\u0650',
    'KASRATAN': '\u064d',
    'SUKUN': '\u0652',
    'SHADDA': '\u0651',
    'SHADDA_FATHA': '\u0651\u064e',
    'SHADDA_FATHATAN': '\u0651\u064b',
    'SHADDA_DAMMA': '\u0651\u064f',
    'SHADDA_DAMMATAN': '\u0651\u064c',
    'SHADDA_KASRA': '\u0651\u0650',
    'SHADDA_KASRATAN': '\u0651\u064d',
}

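
# A minimal sketch (ours, not part of the uploaded file): how a predicted class
# ID turns into combining marks appended after the base letter. The helper name
# `_demo_compose` is purely illustrative.
def _demo_compose():
    base = 'ุจ'
    label_id = DIACRITIC_CLASSES.index('SHADDA_FATHA')
    # shadda and fatha are Unicode combining characters, so plain string
    # concatenation renders them on top of the preceding letter
    print(base + DIACRITIC_MAP[DIACRITIC_CLASSES[label_id]])  # -> ุจู‘ูŽ
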

# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# 2. Tokenizer
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•

class ArabicTokenizer:
    def __init__(self):
        self.char_to_id: Dict[str, int] = {}
        self.id_to_char: Dict[int, str] = {}
        self.vocab_size: int = 0

    def encode(self, text, max_length=512, padding=True):
        ids = [SPECIAL_TOKENS['BOS']]
        for ch in text:
            if ch in DIACRITICS_SET:
                continue
            ids.append(self.char_to_id.get(ch, SPECIAL_TOKENS['UNK']))
        ids.append(SPECIAL_TOKENS['EOS'])

        attention_mask = [1] * len(ids)

        if len(ids) > max_length:
            ids = ids[:max_length]
            attention_mask = attention_mask[:max_length]
        elif padding:
            pad_len = max_length - len(ids)
            ids += [SPECIAL_TOKENS['PAD']] * pad_len
            attention_mask += [0] * pad_len

        return ids, attention_mask

    @classmethod
    def load(cls, path):
        data = json.loads(Path(path).read_text(encoding='utf-8'))
        tok = cls()
        tok.char_to_id = data['char_to_id']
        tok.id_to_char = {int(v): k for k, v in data['char_to_id'].items()}
        tok.vocab_size = data['vocab_size']
        print(f"โœ… Tokenizer: {tok.vocab_size} symbols")
        return tok

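
# A usage sketch (ours): the real vocabulary ships in tokenizer.json, so here
# we build a toy two-letter vocab by hand just to show what encode() returns.
def _demo_tokenizer():
    tok = ArabicTokenizer()
    tok.char_to_id = {**SPECIAL_TOKENS, 'ุจ': 6, 'ู„': 7}
    tok.id_to_char = {v: k for k, v in tok.char_to_id.items()}
    tok.vocab_size = len(tok.char_to_id)
    ids, mask = tok.encode('ุจูŽู„', max_length=8)
    # the fatha is skipped, BOS/EOS wrap the characters, the rest is PAD
    print(ids)   # [2, 6, 7, 3, 0, 0, 0, 0]
    print(mask)  # [1, 1, 1, 1, 0, 0, 0, 0]
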

# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# 3. Model components
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return self.scale * x / rms

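
# A quick numeric check (ours): with scale initialised to ones, RMSNorm
# rescales each feature vector to (near-)unit root-mean-square.
def _demo_rmsnorm():
    norm = RMSNorm(dim=16)
    x = torch.randn(2, 4, 16) * 5.0
    y = norm(x)
    print(y.pow(2).mean(-1).sqrt())  # ~1.0 everywhere
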

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=4096):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        t = torch.arange(max_seq_len).float()
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def forward(self, x, seq_len):
        return (
            self.cos_cached[:seq_len].unsqueeze(0),
            self.sin_cached[:seq_len].unsqueeze(0),
        )


def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
    return torch.cat([-x2, x1], dim=-1)


def apply_rope(q, k, cos, sin):
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

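
# A sanity-check sketch (ours): RoPE's defining property is that the
# query/key dot product depends only on the relative offset between
# positions, which is what lets attention generalise across positions.
def _demo_rope():
    rope = RotaryEmbedding(dim=8, max_seq_len=32)
    q = torch.randn(1, 1, 8).expand(1, 32, 8).clone()  # same vector at every position
    k = torch.randn(1, 1, 8).expand(1, 32, 8).clone()
    cos, sin = rope(q, 32)
    qr, kr = apply_rope(q, k, cos, sin)
    # both pairs are offset by 3 positions, so the scores should match
    s1 = (qr[0, 2] * kr[0, 5]).sum()
    s2 = (qr[0, 10] * kr[0, 13]).sum()
    assert torch.allclose(s1, s2, atol=1e-4), (s1, s2)
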

class SwiGLU(nn.Module):
    def __init__(self, dim, expansion=4):
        super().__init__()
        hidden = int(dim * expansion * 2 / 3)
        hidden = (hidden + 7) // 8 * 8
        self.gate_proj = nn.Linear(dim, hidden, bias=False)
        self.up_proj = nn.Linear(dim, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

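
# Sizing note with a tiny check (ours): the hidden width is 2/3 of the usual
# 4x expansion, rounded up to a multiple of 8, so the three projections cost
# roughly the same as a plain two-matrix 4x MLP.
def _demo_swiglu():
    ffn = SwiGLU(dim=320)
    print(ffn.gate_proj.out_features)          # int(320*4*2/3) = 853 -> rounded to 856
    print(ffn(torch.randn(2, 4, 320)).shape)   # torch.Size([2, 4, 320])
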

class MambaBlock(nn.Module):
    def __init__(self, dim, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_inner = int(dim * expand)
        self.in_proj = nn.Linear(dim, self.d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv,
                                padding=d_conv-1, groups=self.d_inner, bias=True)
        self.out_proj = nn.Linear(self.d_inner, dim, bias=False)
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
        A = torch.arange(1, d_state+1).float().unsqueeze(0).expand(self.d_inner, -1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.norm = RMSNorm(dim)

    def ssm(self, x):
        dt = F.softplus(self.dt_proj(self.x_proj(x)[..., :1]))
        return x * self.D + torch.cumsum(x * dt, dim=1) * 0.1

    def forward(self, x):
        residual = x
        x = self.norm(x)
        xz = self.in_proj(x)
        x_ssm, z = xz.chunk(2, dim=-1)
        # causal depthwise conv: trim the right-side padding back to seq length
        x_conv = self.conv1d(x_ssm.transpose(1, 2))[..., :x_ssm.shape[1]].transpose(1, 2)
        y = self.ssm(F.silu(x_conv)) * F.silu(z)
        return self.out_proj(y) + residual

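
# Shape-check sketch (ours). Note that ssm() above is a lightweight
# cumulative-sum approximation rather than a full selective scan; the block
# is still residual and shape-preserving, which is all this demo verifies.
def _demo_mamba():
    block = MambaBlock(dim=64)
    x = torch.randn(2, 10, 64)
    print(block(x).shape)  # torch.Size([2, 10, 64])
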

class TransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, max_len=4096, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)
        self.rope = RotaryEmbedding(self.head_dim, max_len)
        self.ffn = SwiGLU(dim)
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def attention(self, x, mask=None):
        B, L, D = x.shape
        q = self.q_proj(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
        cos, sin = self.rope(x, L)
        cos = cos.unsqueeze(1).expand_as(q)
        sin = sin.unsqueeze(1).expand_as(q)
        q, k = apply_rope(q, k, cos, sin)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(
                ~mask.unsqueeze(1).unsqueeze(2).bool(), float('-inf')
            )
        attn = self.dropout(F.softmax(scores, dim=-1))
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, L, D)
        return self.o_proj(out)

    def forward(self, x, mask=None):
        x = x + self.dropout(self.attention(self.norm1(x), mask))
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x

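
# A masking sketch (ours): key positions where attention_mask is 0 are filled
# with -inf before the softmax, so padded characters receive no attention.
def _demo_transformer_block():
    blk = TransformerBlock(dim=64, n_heads=4)
    x = torch.randn(1, 6, 64)
    mask = torch.tensor([[1, 1, 1, 1, 0, 0]])  # last two positions are padding
    print(blk(x, mask).shape)  # torch.Size([1, 6, 64])
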

class ArabicDiacritizerModel(nn.Module):
    def __init__(self, vocab_size=50, dim=320, mamba_layers=4,
                 transformer_layers=8, n_heads=8, num_labels=15,
                 max_seq_len=4096, dropout=0.15, d_state=16):
        super().__init__()
        self.num_labels = num_labels
        self.embedding = nn.Embedding(vocab_size, dim, padding_idx=0)
        self.emb_norm = RMSNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.mamba_layers = nn.ModuleList([
            MambaBlock(dim, d_state) for _ in range(mamba_layers)
        ])
        self.transformer_layers = nn.ModuleList([
            TransformerBlock(dim, n_heads, max_seq_len, dropout)
            for _ in range(transformer_layers)
        ])
        self.final_norm = RMSNorm(dim)
        self.classifier = nn.Linear(dim, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask=None, labels=None):
        x = self.dropout(self.emb_norm(self.embedding(input_ids)))
        for m in self.mamba_layers:
            x = m(x)
        for t in self.transformer_layers:
            x = t(x, attention_mask)
        emissions = self.classifier(self.final_norm(x))
        mask = (attention_mask.bool() if attention_mask is not None
                else torch.ones(emissions.shape[:2],
                                dtype=torch.bool, device=emissions.device))
        return {
            'predictions': self.crf.decode(emissions, mask=mask),
            'emissions': emissions,
        }

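
# An end-to-end shape sketch (ours) with random inputs and the default
# config: Mamba layers, then masked attention layers, then CRF decoding,
# which returns one label per unmasked position.
def _demo_model_forward():
    model = ArabicDiacritizerModel(vocab_size=50).eval()
    ids = torch.randint(6, 50, (1, 12))
    mask = torch.ones(1, 12, dtype=torch.long)
    with torch.no_grad():
        out = model(ids, mask)
    print(len(out['predictions'][0]))  # 12 labels, one per position
    print(out['emissions'].shape)      # torch.Size([1, 12, 15])
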

# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# 4. Load the model from HuggingFace
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•

def load_mishkala(repo_id: str = REPO_ID, device: str = None):
    """
    Load the mishkala model from HuggingFace.

    Example:
        model, tokenizer, device = load_mishkala()
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    print(f"๐Ÿ“ฅ Downloading mishkala from {repo_id}...")

    tokenizer_path = hf_hub_download(repo_id=repo_id, filename="tokenizer.json")
    tokenizer = ArabicTokenizer.load(tokenizer_path)

    ckpt_path = hf_hub_download(repo_id=repo_id, filename="mishkala.pt")
    ckpt = torch.load(ckpt_path, map_location=device)
    model_config = ckpt['config']
    model = ArabicDiacritizerModel(**model_config).to(device)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    params = sum(p.numel() for p in model.parameters())
    print(f"โœ… Model ready | Step: {ckpt['step']:,} | DER: {ckpt['der']*100:.2f}%")
    print(f"   {device} | {params:,} parameters")

    return model, tokenizer, device


# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# 5. Diacritization function
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•

def tashkeel(
    text: str,
    model: ArabicDiacritizerModel = None,
    tokenizer: ArabicTokenizer = None,
    device: torch.device = None,
    max_chunk: int = 400,
) -> str:
    """
    Diacritize any Arabic text automatically.

    Args:
        text      : the Arabic text to diacritize
        model     : the model (loaded automatically if not given)
        tokenizer : the tokenizer (loaded automatically if not given)
        device    : the device, cuda/cpu
        max_chunk : maximum length of a single chunk

    Returns:
        the fully diacritized text

    Example:
        model, tokenizer, device = load_mishkala()
        result = tashkeel("ูƒุงู† ุงู„ููŠู„ุณูˆู ูŠุฑู‰ ุฃู† ุงู„ุนู‚ู„ ู…ุฑุขุฉ", model, tokenizer, device)
        print(result)
        # ูƒูŽุงู†ูŽ ุงู„ู’ููŽูŠู’ู„ูŽุณููˆูู ูŠูŽุฑูŽู‰ ุฃูŽู†ูŽู‘ ุงู„ู’ุนูŽู‚ู’ู„ูŽ ู…ูุฑู’ุขุฉูŒ
    """
    # Lazy-load a default model if none was given
    global _default_model, _default_tokenizer, _default_device
    if model is None:
        if '_default_model' not in globals():
            _default_model, _default_tokenizer, _default_device = load_mishkala()
        model, tokenizer, device = _default_model, _default_tokenizer, _default_device
    if device is None:
        # infer the device from the supplied model
        device = next(model.parameters()).device

    # Strip any existing diacritics
    clean = ''.join(c for c in text if c not in DIACRITICS_SET)

    # Split the text at sentence boundaries
    sentences = re.split(r'([.ุŒุŸ!\n])', clean)
    chunks, current = [], ""
    for part in sentences:
        if len(current) + len(part) > max_chunk and current:
            chunks.append(current.strip())
            current = part
        else:
            current += part
    if current.strip():
        chunks.append(current.strip())

    results = []
    for chunk in chunks:
        if not chunk.strip():
            results.append(chunk)
            continue

        input_ids, attention_mask = tokenizer.encode(chunk, max_length=512, padding=True)
        ids_t = torch.tensor([input_ids], dtype=torch.long).to(device)
        mask_t = torch.tensor([attention_mask], dtype=torch.long).to(device)

        with torch.no_grad():
            out = model(ids_t, mask_t)

        pred_labels = out['predictions'][0]
        chars = [c for c in chunk if c not in DIACRITICS_SET]
        result_chars = []

        for i, char in enumerate(chars):
            result_chars.append(char)
            label_idx = i + 1  # offset by 1 to skip the BOS position
            if label_idx < len(pred_labels):
                diacritic = DIACRITIC_MAP.get(
                    DIACRITIC_CLASSES[pred_labels[label_idx]], ''
                )
                result_chars.append(diacritic)

        results.append(''.join(result_chars))

    # chunks were stripped above, so re-join them with single spaces
    return ' '.join(results)

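
# A splitting sketch (ours): re.split with a capturing group keeps the
# sentence delimiters as separate parts, so chunks re-join without losing
# punctuation.
def _demo_chunking():
    print(re.split(r'([.ุŒุŸ!\n])', "ุฌู…ู„ุฉ ุฃูˆู„ู‰. ุฌู…ู„ุฉ ุซุงู†ูŠุฉุŸ"))
    # ['ุฌู…ู„ุฉ ุฃูˆู„ู‰', '.', ' ุฌู…ู„ุฉ ุซุงู†ูŠุฉ', 'ุŸ', '']
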

# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# 6. Direct execution
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•

if __name__ == "__main__":
    model, tokenizer, device = load_mishkala()

    text = "ุงู„ุฅู†ุณุงู† ุจูŠู† ุงู„ุนู‚ู„ ูˆุงู„ุบุฑูŠุฒุฉ"
    print(f"\nโœจ {tashkeel(text, model, tokenizer, device)}")