al1808th committed on
Commit
89195b1
·
1 Parent(s): 8887a7d

Improve macronizer UI and syllable classification output

Browse files
Files changed (1) hide show
  1. app.py +345 -24
app.py CHANGED
@@ -1,40 +1,361 @@
 
 
 
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForTokenClassification
 
 
 
 
4
 
5
- MODEL_ID = "Ericu950/macronizer_mini"
 
6
 
7
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
8
  model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)
 
 
 
9
 
10
  id2label = model.config.id2label
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- def macronize(text):
14
- inputs = tokenizer(text, return_tensors="pt", truncation=True)
15
-
16
- with torch.no_grad():
17
- outputs = model(**inputs)
18
-
19
- logits = outputs.logits
20
- predictions = torch.argmax(logits, dim=-1)[0]
21
 
22
- tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
 
 
 
 
 
 
 
 
23
 
24
- # Pair tokens with labels
25
- result = []
26
- for token, pred in zip(tokens, predictions):
27
- label = id2label[int(pred)]
28
- result.append(f"{token}:{label}")
 
 
 
 
 
 
29
 
30
- return " ".join(result)
 
 
31
 
 
 
32
 
33
- iface = gr.Interface(
34
- fn=macronize,
35
- inputs="text",
36
- outputs="text",
37
- title="Macronizer (Token Classification)"
38
- )
39
 
40
- iface.launch()
 
 
1
import html
import re

import gradio as gr
import torch
from torch.nn.functional import softmax
from transformers import AutoModelForTokenClassification, AutoTokenizer

# Project-local helpers: syllable segmentation and Greek orthography
# normalization (oxia -> tonos, compound-character expansion).
from syllabify import syllabify_joined
from preprocess import process_word, replace_oxia_with_tonos

# Token-classification checkpoint that labels Ancient Greek syllables
# by vowel length (long / short / unmarked).
MODEL_ID = "Ericu950/SyllaMoBert-grc-macronizer-v1"
# Hard cap on tokenized sequence length; inputs are truncated to this.
MAX_LENGTH = 512

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)
# Prefer GPU when available; all inference tensors are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only: disables dropout / training-mode layers

# Mapping from predicted class ids to the model's label names.
id2label = model.config.id2label
22
 
23
def preprocess_greek_line(line):
    """Normalize a line of Greek text and flatten it into letter/diphthong tokens.

    Pipeline: oxia -> tonos normalization, extraction of purely Greek
    words (punctuation, digits, and Latin text are dropped), per-word
    expansion via ``process_word``, and a final flattening across words.

    Args:
        line (str): A full Greek sentence or phrase.

    Returns:
        list of str: Tokens (letters or diphthongs) for the whole line, in order.
    """
    normalized = replace_oxia_with_tonos(line)

    # Match maximal runs of Greek characters, covering the accented and
    # breathing-marked polytonic forms; anything else is discarded.
    greek_word_pattern = re.compile(
        r"[ΆΐΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩάέήίΰαβγδεζηθικλμνξοπρςστυφχψωϊϋόύώ"
        r"ἀἁἂἃἄἅἆἇἈἉἊἋἌἍἎ"
        r"ἐἑἒἓἔἕἘἙἜἝ"
        r"ἠἡἢἣἤἥἦἧἨἩἪἫἬἭἮ"
        r"ἰἱἲἳἴἵἶἷἸἹἺἻἼἽἾ"
        r"ὀὁὂὃὄὅὈὉὊὋὌὍ"
        r"ὐὑὒὓὔὕὖὗὙὛὝ"
        r"ὠὡὢὣὤὥὦὧὨὩὪὫὬὭὮὯ"
        r"ὰὲὴὶὸὺὼᾀᾁᾂᾃᾄᾅᾆᾇᾈᾉᾊᾋᾌᾍ"
        r"ᾐᾑᾒᾓᾔᾕᾖᾗᾘᾙᾚᾛᾜᾝ"
        r"ᾠᾡᾢᾣᾤᾥᾦᾧᾨᾩᾪᾫᾬᾭᾮᾯ"
        r"ᾲᾳᾴᾶᾷῂῃῄῆῇῒῖῗῢῤῥῦῧῬῲῳῴῶῷ]+"
    )
    words = greek_word_pattern.findall(normalized.lower())

    # Expand each word into its tokens and flatten across the line.
    flat_tokens = []
    for word in words:
        flat_tokens.extend(process_word(word))
    return flat_tokens
66
+
67
+ def _normalize_label(raw_label: str) -> int:
68
+ text = raw_label.lower()
69
+ if "long" in text:
70
+ return 1
71
+ if "short" in text:
72
+ return 2
73
+ return 0
74
+
75
+
76
+ def _fallback_preprocess(line: str):
77
+ return re.findall(r"[\wἀ-῾]+|[^\w\s]", line, flags=re.UNICODE)
78
+
79
+
80
+ def _fallback_syllabify(tokens):
81
+ return [t for t in tokens if re.search(r"[\wἀ-῾]", t, flags=re.UNICODE)]
82
+
83
+
84
def preprocess_and_syllabify(line: str):
    """Split *line* into syllables, preferring the project pipeline.

    Uses ``preprocess_greek_line`` + ``syllabify_joined`` when both names
    are truthy; otherwise falls back to the crude regex-based split.
    """
    if not (preprocess_greek_line and syllabify_joined):
        return _fallback_syllabify(_fallback_preprocess(line))
    return syllabify_joined(preprocess_greek_line(line))
90
+
91
+
92
def classify_line(line: str):
    """Classify each syllable of *line* as long / short / unmarked.

    Args:
        line (str): A single line of Greek text.

    Returns:
        list of (syllable, label_id) pairs, in input order, where
        label_id is 1 = long, 2 = short, 0 = unmarked
        (see ``_normalize_label``). Empty list if no syllables are found.
    """
    syllables = preprocess_and_syllabify(line)
    if not syllables:
        return []

    encoding = tokenizer(
        syllables,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    )

    # word_ids() maps each subword token back to the input syllable it
    # came from (None for special tokens). Capture it before the
    # BatchEncoding is flattened into a plain tensor dict below.
    # NOTE(review): word_ids() requires a fast tokenizer — AutoTokenizer
    # loads one by default when available; confirm for this checkpoint.
    word_ids = encoding.word_ids(batch_index=0)

    # Some models reject token_type_ids; drop them if present.
    if "token_type_ids" in encoding:
        del encoding["token_type_ids"]

    inputs = {k: v.to(device) for k, v in encoding.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        probs = softmax(outputs.logits, dim=-1)
        predictions = torch.argmax(probs, dim=-1).squeeze(0).cpu().tolist()

    # BUGFIX: the previous version advanced the syllable index once per
    # non-special token, so any syllable the tokenizer split into several
    # subword pieces shifted every subsequent label. Align through
    # word_ids instead, taking the prediction of the FIRST subword piece
    # of each syllable.
    aligned = []
    seen_word_ids = set()
    for position, word_id in enumerate(word_ids):
        # Skip special tokens and continuation pieces of an already-seen syllable.
        if word_id is None or word_id in seen_word_ids:
            continue
        seen_word_ids.add(word_id)
        pred_id = int(predictions[position])
        label_name = id2label.get(pred_id, str(pred_id))
        aligned.append((syllables[word_id], _normalize_label(str(label_name))))

    return aligned
132
+
133
+
134
+ def _syllable_chip(syllable: str, label_id: int) -> str:
135
+ escaped = html.escape(syllable)
136
+ if label_id == 1:
137
+ return f'<span class="chip long">{escaped}<small>long</small></span>'
138
+ if label_id == 2:
139
+ return f'<span class="chip short">{escaped}<small>short</small></span>'
140
+ return f'<span class="chip clear">{escaped}</span>'
141
+
142
+
143
def render_results(text: str):
    """Classify every non-empty line of *text* and build both output views.

    Args:
        text (str): Possibly multi-line user input; blank lines are skipped.

    Returns:
        tuple (html_result, plain_text): a styled HTML rendering (legend +
        one card per line of syllable chips) and a plain-text export of
        the same classifications.
    """
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if not lines:
        # Nothing to classify: show a friendly placeholder, empty export.
        return "<div class='empty'>Enter one or more Greek lines to classify syllables.</div>", ""

    cards = []
    export_lines = []

    for idx, line in enumerate(lines, start=1):
        aligned = classify_line(line)
        chips = "".join(_syllable_chip(syl, label) for syl, label in aligned)

        cards.append(
            f"""
            <section class="card">
                <div class="line-number">Line {idx}</div>
                <div class="source">{html.escape(line)}</div>
                <div class="chips">{chips or '<span class="chip clear">(no syllables found)</span>'}</div>
            </section>
            """
        )

        # Plain-text mirror of the HTML view: header line, then one
        # indented "- syllable: tag" row per syllable.
        export_lines.append(f"Line {idx}: {line}")
        for syl, label in aligned:
            tag = "long" if label == 1 else "short" if label == 2 else "clear"
            export_lines.append(f"  - {syl}: {tag}")

    # Legend first, then the cards in input order.
    html_result = (
        "<div class='legend'><span class='dot long'></span>Long"
        "<span class='dot short'></span>Short"
        "<span class='dot clear'></span>Unmarked</div>"
        + "".join(cards)
    )

    return html_result, "\n".join(export_lines)
178
+
179
+
180
# Sample inputs for the UI examples widget; "\n" separates multiple
# verse lines within a single example (each line is classified separately).
examples = [
    "νεανίας ἀάατός ἐστιν καὶ καλός. τὰ παῖδες τὰ καλά\nκαλὰ μὲν ἠέξευ, καλὰ δ᾽ ἔτραφες, οὐράνιε Ζεῦ,",
    "Ἆρες, Ἄρες βροτολοιγὲ μιαιφόνε τειχεσιπλῆτα\nἈτρεΐδαι τε καὶ ἄλλοι ἐϋκνήμιδες Ἀχαιοί",
    "ἢ τυφλὸς ἤ τις σκνιπὸς ἢ λέγα βλέπων\nψάμμου θαλασσῶν ἢ σκνιπῶν Αἰγυπτίων",
]
185
+
186
+
187
# Custom stylesheet for the Gradio app: parchment-toned palette, serif
# Greek text, and colour-coded syllable chips (long = red, short = teal,
# unmarked = grey). Class names here match the HTML built by
# render_results / _syllable_chip.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Cormorant+Garamond:wght@500;600;700&family=Space+Grotesk:wght@400;500;700&display=swap');

:root {
    --bg-start: #f2eee6;
    --bg-end: #ddd5c6;
    --ink: #2f2b26;
    --long: #ba3a29;
    --short: #1f6f6d;
    --clear: #7c7369;
    --paper: rgba(255, 251, 244, 0.88);
}

.gradio-container {
    font-family: 'Space Grotesk', sans-serif;
    background: radial-gradient(circle at top left, var(--bg-start), var(--bg-end));
    color: var(--ink);
}

.title h1 {
    font-family: 'Cormorant Garamond', serif;
    font-size: 3rem;
    letter-spacing: 0.02em;
    margin-bottom: 0.2rem;
}

.title p {
    opacity: 0.82;
}

.panel {
    backdrop-filter: blur(8px);
    background: var(--paper);
    border: 1px solid rgba(47, 43, 38, 0.18);
    border-radius: 18px;
    padding: 0.9rem;
}

.legend {
    display: flex;
    align-items: center;
    gap: 0.9rem;
    font-weight: 600;
    margin-bottom: 0.8rem;
}

.dot {
    display: inline-block;
    width: 10px;
    height: 10px;
    border-radius: 999px;
    margin-left: 0.7rem;
    margin-right: 0.25rem;
}

.dot.long { background: var(--long); }
.dot.short { background: var(--short); }
.dot.clear { background: var(--clear); }

.card {
    background: rgba(255, 255, 255, 0.72);
    border-radius: 14px;
    padding: 0.9rem;
    margin: 0.8rem 0;
    border: 1px solid rgba(47, 43, 38, 0.12);
    animation: rise 420ms ease both;
}

.line-number {
    font-size: 0.8rem;
    font-weight: 700;
    text-transform: uppercase;
    letter-spacing: 0.06em;
    color: #5c544b;
}

.source {
    font-family: 'Cormorant Garamond', serif;
    font-size: 1.45rem;
    margin: 0.25rem 0 0.7rem;
}

.chips {
    display: flex;
    flex-wrap: wrap;
    gap: 0.45rem;
}

.chip {
    display: inline-flex;
    align-items: baseline;
    gap: 0.35rem;
    border-radius: 999px;
    padding: 0.28rem 0.65rem;
    font-family: 'Cormorant Garamond', serif;
    font-size: 1.1rem;
    border: 1px solid transparent;
}

.chip small {
    font-size: 0.75rem;
    font-family: 'Space Grotesk', sans-serif;
    text-transform: uppercase;
    letter-spacing: 0.04em;
}

.chip.long {
    color: var(--long);
    background: rgba(186, 58, 41, 0.09);
    border-color: rgba(186, 58, 41, 0.2);
}

.chip.short {
    color: var(--short);
    background: rgba(31, 111, 109, 0.1);
    border-color: rgba(31, 111, 109, 0.2);
}

.chip.clear {
    color: #544e46;
    background: rgba(116, 108, 95, 0.08);
    border-color: rgba(116, 108, 95, 0.18);
}

.empty {
    padding: 1rem;
    border-radius: 12px;
    background: rgba(255, 255, 255, 0.6);
    border: 1px dashed rgba(47, 43, 38, 0.2);
}

@keyframes rise {
    from { transform: translateY(8px); opacity: 0; }
    to { transform: translateY(0); opacity: 1; }
}

@media (max-width: 820px) {
    .title h1 { font-size: 2.2rem; }
    .source { font-size: 1.25rem; }
}
"""
328
 
 
 
 
 
 
 
 
 
329
 
330
# Build the Gradio UI: input column (textbox, buttons, examples) on the
# left, results column (styled HTML plus plain-text export) on the right.
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        <div class="title">
            <h1>Ancient Greek Macronizer</h1>
            <p>Syllable-level long/short classification with a modern, readable presentation.</p>
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=5, elem_classes=["panel"]):
            text_input = gr.Textbox(
                label="Greek Lines",
                lines=8,
                placeholder="Paste one or multiple lines; each line is processed separately.",
            )
            with gr.Row():
                classify_btn = gr.Button("Classify", variant="primary")
                clear_btn = gr.Button("Clear")
            gr.Examples(examples=examples, inputs=text_input, label="Try examples")

        with gr.Column(scale=6, elem_classes=["panel"]):
            html_output = gr.HTML(label="Styled Results")
            text_output = gr.Textbox(label="Plain Output", lines=12)

    # Wire up the buttons: Classify renders both views; Clear resets the
    # input and both outputs in one shot (three components, three values).
    classify_btn.click(render_results, inputs=text_input, outputs=[html_output, text_output])
    clear_btn.click(lambda: ("", "", ""), outputs=[text_input, html_output, text_output])


if __name__ == "__main__":
    demo.launch()