al1808th committed on
Commit
1bd6db7
·
1 Parent(s): 07f1474

Add model toggle for current and mini checkpoints

Browse files
Files changed (1) hide show
  1. app.py +112 -87
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import html
2
  import re
3
- from typing import List
4
 
5
  import gradio as gr
6
  import torch
@@ -11,17 +11,32 @@ from grc_utils import lower_grc, normalize_word, heavy
11
 
12
  from syllabify import syllabify_joined
13
  from preprocess import process_word
14
- +
15
- MODEL_ID = "Ericu950/SyllaMoBert-grc-macronizer-v1"
 
 
 
 
 
16
  MAX_LENGTH = 512
17
 
18
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
19
- model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
- model.to(device)
22
- model.eval()
23
 
24
- id2label = model.config.id2label
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  def preprocess_greek_line(line: str) -> List[str]:
@@ -51,49 +66,51 @@ def preprocess_and_syllabify(line: str):
51
  return syllabify_joined(tokens)
52
 
53
 
54
- def classify_line(line: str):
55
- syllables = preprocess_and_syllabify(line)
56
- if not syllables:
57
- return []
58
 
59
- encoded = tokenizer(
60
- syllables,
61
- is_split_into_words=True,
62
- return_tensors="pt",
63
- truncation=True,
64
- max_length=MAX_LENGTH,
65
- )
 
 
66
 
67
- word_ids = encoded.word_ids(batch_index=0)
68
 
69
- if "token_type_ids" in encoded:
70
- del encoded["token_type_ids"]
71
 
72
- model_inputs = {k: v.to(device) for k, v in encoded.items()}
73
 
74
- with torch.no_grad():
75
- outputs = model(**model_inputs)
76
- probs = F.softmax(outputs.logits, dim=-1)
77
- predictions = torch.argmax(probs, dim=-1).squeeze(0).cpu().tolist()
78
 
79
- aligned = []
80
- seen_word_ids = set()
81
 
82
- for i, word_id in enumerate(word_ids):
83
- if word_id is None:
84
- continue
85
- if word_id in seen_word_ids:
86
- continue
87
- if word_id >= len(syllables):
88
- break
89
 
90
- seen_word_ids.add(word_id)
91
- pred_id = int(predictions[i])
92
- label_name = id2label.get(pred_id, str(pred_id))
93
- normalized = _normalize_label(str(label_name))
94
- aligned.append((syllables[word_id], normalized))
95
 
96
- return aligned
97
 
98
 
99
  def _syllable_chip(syllable: str, label_id: int) -> str:
@@ -105,41 +122,44 @@ def _syllable_chip(syllable: str, label_id: int) -> str:
105
  return f'<span class="chip clear">{escaped}</span>'
106
 
107
 
108
- def render_results(text: str):
109
- lines = [line.strip() for line in text.splitlines() if line.strip()]
110
- if not lines:
111
- return "<div class='empty'>Enter one or more Greek lines to classify syllables.</div>", ""
112
 
113
- cards = []
114
- export_lines = []
115
 
116
- for idx, line in enumerate(lines, start=1):
117
- aligned = classify_line(line)
118
- chips = "".join(_syllable_chip(syl, label) for syl, label in aligned)
119
 
120
- cards.append(
121
- f"""
122
- <section class="card">
123
- <div class="line-number">Line {idx}</div>
124
- <div class="source">{html.escape(line)}</div>
125
- <div class="chips">{chips or '<span class="chip clear">(no syllables found)</span>'}</div>
126
- </section>
127
- """
128
- )
129
-
130
- export_lines.append(f"Line {idx}: {line}")
131
- for syl, label in aligned:
132
- tag = "long" if label == 1 else "short" if label == 2 else "clear"
133
- export_lines.append(f" - {syl}: {tag}")
134
 
135
- html_result = (
136
- "<div class='legend'><span class='dot long'></span>Long"
137
- "<span class='dot short'></span>Short"
138
- "<span class='dot clear'></span>Unmarked</div>"
139
- + "".join(cards)
 
 
 
140
  )
141
 
142
- return html_result, "\n".join(export_lines)
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
 
145
  examples = [
@@ -454,22 +474,27 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
454
  )
455
 
456
  with gr.Column():
457
- with gr.Column(elem_classes=["panel"]):
458
- text_input = gr.Textbox(
459
- label="Greek Lines",
460
- lines=8,
461
- placeholder="Paste one or multiple lines; each line is processed separately.",
462
- )
463
- with gr.Row():
464
- classify_btn = gr.Button("Classify", variant="primary")
465
- clear_btn = gr.Button("Clear")
466
- gr.Examples(examples=examples, inputs=text_input, label="Try examples")
467
-
468
- with gr.Column(elem_classes=["panel"]):
469
- html_output = gr.HTML(label="Styled Results")
470
- text_output = gr.Textbox(label="Plain Output", lines=12)
471
-
472
- classify_btn.click(render_results, inputs=text_input, outputs=[html_output, text_output])
 
 
 
 
 
473
  clear_btn.click(lambda: ("", "", ""), outputs=[text_input, html_output, text_output])
474
 
475
 
 
1
  import html
2
  import re
3
+ from typing import Dict, List, Tuple
4
 
5
  import gradio as gr
6
  import torch
 
11
 
12
  from syllabify import syllabify_joined
13
  from preprocess import process_word
14
# --- Model configuration -------------------------------------------------
# UI label -> Hugging Face hub id. The label is what the Radio widget shows;
# the id is what is handed to from_pretrained().
MODEL_OPTIONS: Dict[str, str] = {
    "SyllaMoBert (current)": "Ericu950/SyllaMoBert-grc-macronizer-v1",
    "Macronizer Mini": "Ericu950/macronizer_mini",
}
DEFAULT_MODEL_LABEL = "SyllaMoBert (current)"
DEFAULT_MODEL_ID = MODEL_OPTIONS[DEFAULT_MODEL_LABEL]

# Tokenizer truncation limit for one line's syllable sequence.
MAX_LENGTH = 512

# Prefer GPU when available; everything is moved to this device once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lazy per-checkpoint cache: hub id -> (tokenizer, model, id2label).
# Populated by _get_model_bundle so toggling models in the UI does not
# re-download or re-load weights.
_MODEL_CACHE: Dict[str, Tuple[AutoTokenizer, AutoModelForTokenClassification, Dict[int, str]]] = {}
26
+
27
+
28
def _get_model_bundle(model_id: str) -> Tuple[AutoTokenizer, AutoModelForTokenClassification, Dict[int, str]]:
    """Return (tokenizer, model, id2label) for *model_id*, loading it at most once.

    The first call for a given hub id downloads/loads the checkpoint, moves
    the model to the module-level ``device`` and puts it in eval mode; the
    result is memoized in ``_MODEL_CACHE`` for every later call.
    """
    # EAFP: the cache hit is the common path after the first request.
    try:
        return _MODEL_CACHE[model_id]
    except KeyError:
        pass

    tok = AutoTokenizer.from_pretrained(model_id)
    clf = AutoModelForTokenClassification.from_pretrained(model_id)
    clf.to(device)
    clf.eval()  # inference only — disable dropout etc.

    bundle = (tok, clf, clf.config.id2label)
    _MODEL_CACHE[model_id] = bundle
    return bundle
40
 
41
 
42
  def preprocess_greek_line(line: str) -> List[str]:
 
66
  return syllabify_joined(tokens)
67
 
68
 
69
def classify_line(line: str, model_id: str):
    """Syllabify one line of Greek and predict a label per syllable.

    Parameters
    ----------
    line : raw Greek text for a single line.
    model_id : Hugging Face hub id resolved via ``_get_model_bundle``.

    Returns
    -------
    List of ``(syllable, normalized_label)`` tuples, aligned to the first
    sub-token of each syllable; ``[]`` when the line yields no syllables.
    """
    syllables = preprocess_and_syllabify(line)
    if not syllables:
        return []

    tokenizer, model, id2label = _get_model_bundle(model_id)

    encoded = tokenizer(
        syllables,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    )

    # Sub-token -> syllable-index map; grab it before mutating `encoded`.
    word_ids = encoded.word_ids(batch_index=0)

    # Some tokenizers emit token_type_ids the classifier does not accept.
    if "token_type_ids" in encoded:
        del encoded["token_type_ids"]

    model_inputs = {k: v.to(device) for k, v in encoded.items()}

    with torch.no_grad():
        logits = model(**model_inputs).logits
        # argmax over raw logits equals argmax over softmax(logits) —
        # softmax is strictly monotonic — so the normalization the old
        # code did here was redundant work and is dropped.
        predictions = torch.argmax(logits, dim=-1).squeeze(0).cpu().tolist()

    aligned = []
    seen_word_ids = set()

    for i, word_id in enumerate(word_ids):
        if word_id is None:  # special tokens ([CLS]/[SEP]/pad)
            continue
        if word_id in seen_word_ids:  # label each syllable from its first sub-token only
            continue
        if word_id >= len(syllables):
            break

        seen_word_ids.add(word_id)
        pred_id = int(predictions[i])
        label_name = id2label.get(pred_id, str(pred_id))
        normalized = _normalize_label(str(label_name))
        aligned.append((syllables[word_id], normalized))

    return aligned
114
 
115
 
116
  def _syllable_chip(syllable: str, label_id: int) -> str:
 
122
  return f'<span class="chip clear">{escaped}</span>'
123
 
124
 
125
def render_results(text: str, model_label: str):
    """Classify every non-blank line of *text* with the selected model.

    Returns a pair: styled HTML (legend + one card per line) and a plain-text
    export that starts with a header naming the model used.
    """
    lines = [candidate.strip() for candidate in text.splitlines() if candidate.strip()]
    if not lines:
        return "<div class='empty'>Enter one or more Greek lines to classify syllables.</div>", ""

    # Fall back to the default checkpoint if the label is unrecognized.
    model_id = MODEL_OPTIONS.get(model_label, DEFAULT_MODEL_ID)

    cards = []
    export_lines = []

    for idx, line in enumerate(lines, start=1):
        aligned = classify_line(line, model_id)
        chips = "".join(_syllable_chip(syl, label) for syl, label in aligned)
        chips_html = chips or '<span class="chip clear">(no syllables found)</span>'

        card = f"""
        <section class="card">
          <div class="line-number">Line {idx}</div>
          <div class="source">{html.escape(line)}</div>
          <div class="chips">{chips_html}</div>
        </section>
        """
        cards.append(card)

        export_lines.append(f"Line {idx}: {line}")
        for syl, label in aligned:
            tag = "long" if label == 1 else "short" if label == 2 else "clear"
            export_lines.append(f" - {syl}: {tag}")

    legend = (
        "<div class='legend'><span class='dot long'></span>Long"
        "<span class='dot short'></span>Short"
        "<span class='dot clear'></span>Unmarked</div>"
    )
    html_result = legend + "".join(cards)

    export_header = [f"Model: {model_label} ({model_id})", ""]
    return html_result, "\n".join(export_header + export_lines)
163
 
164
 
165
  examples = [
 
474
  )
475
 
476
  with gr.Column():
477
+ with gr.Column(elem_classes=["panel"]):
478
+ model_choice = gr.Radio(
479
+ label="Model",
480
+ choices=list(MODEL_OPTIONS.keys()),
481
+ value=DEFAULT_MODEL_LABEL,
482
+ )
483
+ text_input = gr.Textbox(
484
+ label="Greek Lines",
485
+ lines=8,
486
+ placeholder="Paste one or multiple lines; each line is processed separately.",
487
+ )
488
+ with gr.Row():
489
+ classify_btn = gr.Button("Classify", variant="primary")
490
+ clear_btn = gr.Button("Clear")
491
+ gr.Examples(examples=examples, inputs=text_input, label="Try examples")
492
+
493
+ with gr.Column(elem_classes=["panel"]):
494
+ html_output = gr.HTML(label="Styled Results")
495
+ text_output = gr.Textbox(label="Plain Output", lines=12)
496
+
497
+ classify_btn.click(render_results, inputs=[text_input, model_choice], outputs=[html_output, text_output])
498
  clear_btn.click(lambda: ("", "", ""), outputs=[text_input, html_output, text_output])
499
 
500