macronizer / app.py
Urdatorn's picture
mcp args
3888c53
import html
import re
from typing import Dict, List, Literal, Tuple
import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from transformers import AutoModelForTokenClassification, AutoTokenizer
from grc_utils import lower_grc, normalize_word, vowel, only_bases
from syllabify import syllabify_joined
from preprocess import process_word
# Models offered in the UI: display label -> identifier handed to
# ``from_pretrained``.
MODEL_OPTIONS: Dict[str, str] = {
    "Pretrained ModernBERT": "Ericu950/SyllaMoBert-grc-macronizer-v1",
    "Fine-tuned RoBERTa": "Ericu950/macronizer_mini",
}

# Model used whenever a caller supplies no recognisable label or alias.
DEFAULT_MODEL_LABEL = "Fine-tuned RoBERTa"
DEFAULT_MODEL_ID = MODEL_OPTIONS[DEFAULT_MODEL_LABEL]

# Short lowercase names accepted in place of a display label; each one points
# at the identifier of the label it abbreviates.
MODEL_ALIASES: Dict[str, str] = {
    alias: MODEL_OPTIONS[label]
    for alias, label in (
        ("modernbert", "Pretrained ModernBERT"),
        ("syllamobert", "Pretrained ModernBERT"),
        ("roberta", "Fine-tuned RoBERTa"),
    )
}

# Passed to the tokenizer as ``max_length`` (with truncation enabled).
MAX_LENGTH = 512
# Loaded (tokenizer, model, id2label) bundles keyed by model id, filled lazily
# by _get_model_bundle so each model is loaded from_pretrained only once.
_MODEL_CACHE: Dict[str, Tuple[AutoTokenizer, AutoModelForTokenClassification, Dict[int, str]]] = {}
def _current_device() -> torch.device:
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _resolve_model_id(model_label: str | None) -> str:
    """Translate a UI label or short alias into a model identifier.

    An exact key of ``MODEL_OPTIONS`` wins. Otherwise the value is stripped
    and lower-cased and looked up in ``MODEL_ALIASES``. Anything unrecognised
    (including ``None`` and the empty string) yields ``DEFAULT_MODEL_ID``.
    """
    exact_match = MODEL_OPTIONS.get(model_label)
    if exact_match is not None:
        return exact_match

    cleaned = model_label.strip().lower() if model_label else ""
    if cleaned in MODEL_ALIASES:
        return MODEL_ALIASES[cleaned]
    return DEFAULT_MODEL_ID
def _get_model_bundle(model_id: str) -> Tuple[AutoTokenizer, AutoModelForTokenClassification, Dict[int, str]]:
    """Return the ``(tokenizer, model, id2label)`` bundle for ``model_id``.

    The first request for a model loads tokenizer and model with
    ``from_pretrained``, moves the model to the current device, switches it to
    eval mode and stores the bundle in ``_MODEL_CACHE``. Later requests reuse
    the stored bundle, moving the model to the current device again first.
    """
    target_device = _current_device()
    bundle = _MODEL_CACHE.get(model_id)

    if bundle is None:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        classifier = AutoModelForTokenClassification.from_pretrained(model_id)
        classifier.to(target_device)
        classifier.eval()
        bundle = (tokenizer, classifier, classifier.config.id2label)
        _MODEL_CACHE[model_id] = bundle
    else:
        # Move the cached model to whatever device is current for this call.
        _, cached_model, _ = bundle
        cached_model.to(target_device)

    return bundle
def preprocess_greek_line(line: str) -> List[str]:
    """Turn a raw line into one flat list of tokens.

    The line goes through ``normalize_word`` and ``lower_grc``, is split on
    whitespace, and every resulting word is expanded by ``process_word``; the
    per-word token lists are concatenated in word order.
    """
    flat_tokens: List[str] = []
    for word in lower_grc(normalize_word(line)).split():
        flat_tokens.extend(process_word(word))
    return flat_tokens
def _normalize_label(raw_label: str) -> int:
text = raw_label.lower()
if "long" in text:
return 1
if "short" in text:
return 2
if text.endswith("_1") or text == "1":
return 1
if text.endswith("_2") or text == "2":
return 2
return 0
def preprocess_and_syllabify(line: str):
    """Preprocess ``line`` into tokens and return ``syllabify_joined`` of them."""
    return syllabify_joined(preprocess_greek_line(line))
def classify_line_per_word(line: str, model_id: str) -> List[Tuple[str, List[Tuple[str, int]]]]:
    """Classify every whitespace-delimited word of ``line`` on its own.

    The line is cut into alternating runs of non-whitespace and whitespace.
    Whitespace runs are returned with an empty syllable list; each word is
    returned in its original spelling together with the result of
    ``classify_line`` on the word with every "ς" replaced by "σ".
    """
    per_word: List[Tuple[str, List[Tuple[str, int]]]] = []
    for original_chunk in re.findall(r"\S+|\s+", line):
        # Swapping one non-space letter for another never moves a chunk
        # boundary, so the matching form can be derived chunk by chunk.
        matching_chunk = original_chunk.replace("ς", "σ")
        if matching_chunk.isspace():
            per_word.append((matching_chunk, []))
        else:
            per_word.append((original_chunk, classify_line(matching_chunk, model_id)))
    return per_word
def classify_line(line: str, model_id: str):
    """Predict a quantity label for each syllable of ``line``.

    Returns a list of ``(syllable, label)`` pairs in syllable order, where
    ``label`` is the value ``_normalize_label`` gives for the predicted class
    name. The list is empty when preprocessing produces no syllables.
    """
    syllables = preprocess_and_syllabify(line)
    if not syllables:
        return []

    device = _current_device()
    tokenizer, model, id2label = _get_model_bundle(model_id)

    encoded = tokenizer(
        syllables,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    )
    # word_ids maps each encoded position back to the index of the syllable
    # it came from (None for positions that belong to no syllable).
    word_ids = encoded.word_ids(batch_index=0)
    # token_type_ids, when the tokenizer emits them, are not forwarded.
    if "token_type_ids" in encoded:
        del encoded["token_type_ids"]
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**model_inputs).logits
        class_probs = F.softmax(logits, dim=-1)
        predictions = torch.argmax(class_probs, dim=-1).squeeze(0).cpu().tolist()

    aligned = []
    handled_syllables = set()
    for position, word_id in enumerate(word_ids):
        # Only the first encoded position of each syllable decides its label.
        if word_id is None or word_id in handled_syllables:
            continue
        if word_id >= len(syllables):
            break
        handled_syllables.add(word_id)
        class_id = int(predictions[position])
        class_name = id2label.get(class_id, str(class_id))
        aligned.append((syllables[word_id], _normalize_label(str(class_name))))
    return aligned
def _style_syllable_vowels(syllable: str, label_id: int) -> str:
if label_id == 1:
label_class = "long"
elif label_id == 2:
label_class = "short"
else:
return html.escape(syllable)
out = []
for ch in syllable:
escaped = html.escape(ch)
if vowel(ch):
out.append(f'<span class="vowel {label_class}">{escaped}</span>')
else:
out.append(escaped)
return "".join(out)
def _is_greek(ch: str) -> bool:
return "\u0370" <= ch <= "\u03ff" or "\u1f00" <= ch <= "\u1fff"
def _expanded_greek_unit_count(text: str) -> int:
    """Count the Greek characters of ``text``, with ζ, ξ and ψ worth two.

    Characters rejected by ``_is_greek`` contribute nothing. Each remaining
    character is normalised, lower-cased and has ς/ῥ folded to σ/ρ before the
    double-consonant test.
    """
    # Presumably these count double because preprocessing spells them as two
    # letters (compare the δσ/κσ/πσ replacements in _restore_expanded_word).
    double_units = {"ζ", "ξ", "ψ"}

    total = 0
    for ch in text:
        if _is_greek(ch):
            folded = lower_grc(normalize_word(ch)).replace("ς", "σ").replace("ῥ", "ρ")
            total += 2 if folded in double_units else 1
    return total
def _segment_reference_by_syllable(
reference_word: str,
syllables: List[Tuple[str, int]],
) -> List[Tuple[str, int]]:
counted = [
(syllable, label, _expanded_greek_unit_count(syllable))
for syllable, label in syllables
]
counted = [(syllable, label, count) for syllable, label, count in counted if count > 0]
if not counted:
return [(reference_word, 0)] if reference_word else []
segments: List[List[object]] = [["", counted[0][1]]]
syllable_idx = 0
units_in_syllable = 0
for ch in reference_word:
units = _expanded_greek_unit_count(ch)
if units == 0:
segments[-1][0] = str(segments[-1][0]) + ch
continue
segments[-1][0] = str(segments[-1][0]) + ch
units_in_syllable += units
while syllable_idx < len(counted) and units_in_syllable >= counted[syllable_idx][2]:
units_in_syllable -= counted[syllable_idx][2]
syllable_idx += 1
if syllable_idx >= len(counted):
break
segments.append(["", counted[syllable_idx][1]])
return [(str(text), int(label)) for text, label in segments if text]
def _mark_syllable_plain(syllable: str, label_id: int) -> str:
if label_id not in (1, 2):
return syllable
marker = "_" if label_id == 1 else "^"
chars = list(syllable)
# Find the last vowel, skipping trailing non-letter characters (like punctuation)
vowel_idx = -1
for i in range(len(chars) - 1, -1, -1):
if vowel(chars[i]):
vowel_idx = i
break
# Skip markup characters and non-letter characters
ch = chars[i]
if ch not in "^_" and ("\u0370" <= ch <= "\u03ff" or "\u1f00" <= ch <= "\u1fff"):
# It's a Greek letter but not a vowel
continue
if not ("\u0370" <= ch <= "\u03ff" or "\u1f00" <= ch <= "\u1fff") and ch not in "^_":
# It's not a Greek letter or markup - it's punctuation or other, skip
continue
if vowel_idx >= 0:
# Found a vowel, insert marker after it
return "".join(chars[:vowel_idx + 1]) + marker + "".join(chars[vowel_idx + 1:])
# No vowel found, find the last Greek letter
for i in range(len(chars) - 1, -1, -1):
if "\u0370" <= chars[i] <= "\u03ff" or "\u1f00" <= chars[i] <= "\u1fff":
# Insert marker after the last Greek letter
return "".join(chars[:i + 1]) + marker + "".join(chars[i + 1:])
# No Greek letters found, append marker at the end
return syllable + marker
def _to_final_sigma(text: str) -> str:
# Step 3: in rendered output, only word-final sigmas become final-sigma.
def _convert_word(token: str) -> str:
if not token.strip():
return token
chars = list(token)
last_greek_idx = -1
for i, ch in enumerate(chars):
if "\u0370" <= ch <= "\u03ff" or "\u1f00" <= ch <= "\u1fff":
last_greek_idx = i
if last_greek_idx != -1 and chars[last_greek_idx] == "σ":
chars[last_greek_idx] = "ς"
return "".join(chars)
return "".join(_convert_word(tok) for tok in re.findall(r"\S+|\s+", text))
def _restore_expanded_word(marked_word: str, reference_word: str) -> str:
    """Undo preprocessing expansions in ``marked_word`` using ``reference_word``.

    Steps, in order: collapse δσ/κσ/πσ back to ζ/ξ/ψ; if the normalised,
    lower-cased reference contains ῥ, turn the first ρ into ῥ; copy letter
    case from the reference; copy final sigmas from the reference; finally
    convert any remaining word-final σ to ς.
    """
    restored = marked_word
    for expanded, single in (("δσ", "ζ"), ("κσ", "ξ"), ("πσ", "ψ")):
        restored = restored.replace(expanded, single)

    if "ῥ" in lower_grc(normalize_word(reference_word)):
        first_rho = restored.find("ρ")
        if first_rho != -1:
            restored = restored[:first_rho] + "ῥ" + restored[first_rho + 1 :]

    restored = _apply_case_from_reference(restored, reference_word)
    restored = _preserve_final_sigma_from_reference(restored, reference_word)
    return _to_final_sigma(restored)
def _apply_case_from_reference(text: str, reference: str) -> str:
"""Apply case from reference word to text (only for Greek letters)."""
result = []
ref_idx = 0
for char in text:
# Skip markup characters
if char in "^_":
result.append(char)
continue
# For Greek letters, find corresponding reference character and apply case
if "\u0370" <= char <= "\u03ff" or "\u1f00" <= char <= "\u1fff":
# Find next Greek letter in reference
while ref_idx < len(reference) and not ("\u0370" <= reference[ref_idx] <= "\u03ff" or "\u1f00" <= reference[ref_idx] <= "\u1fff"):
ref_idx += 1
if ref_idx < len(reference):
ref_char = reference[ref_idx]
# Check if reference character is uppercase
if ref_char.isupper() or ref_char != lower_grc(ref_char):
# Try to apply uppercase version
upper_version = char.upper()
if upper_version != char: # Character has an uppercase form
result.append(upper_version)
else:
result.append(char)
else:
result.append(char)
ref_idx += 1
else:
result.append(char)
else:
result.append(char)
return "".join(result)
def _preserve_final_sigma_from_reference(text: str, reference: str) -> str:
"""If reference has final sigma ς at word end, preserve it in text."""
# Simply copy final sigmas from reference to text at word boundaries
# Split both into tokens
text_tokens = re.findall(r"\S+|\s+", text)
ref_tokens = re.findall(r"\S+|\s+", reference)
result = []
for text_token, ref_token in zip(text_tokens, ref_tokens):
if text_token.isspace() or ref_token.isspace():
result.append(text_token)
continue
# Find last Greek letter in both tokens
text_last_greek_idx = -1
ref_last_greek_idx = -1
for i in range(len(text_token) - 1, -1, -1):
ch = text_token[i]
if "\u0370" <= ch <= "\u03ff" or "\u1f00" <= ch <= "\u1fff":
text_last_greek_idx = i
break
for i in range(len(ref_token) - 1, -1, -1):
ch = ref_token[i]
if "\u0370" <= ch <= "\u03ff" or "\u1f00" <= ch <= "\u1fff":
ref_last_greek_idx = i
break
# If reference ends with final sigma ς, convert text's final sigma to match
if ref_last_greek_idx >= 0 and ref_token[ref_last_greek_idx] == "ς":
if text_last_greek_idx >= 0 and text_token[text_last_greek_idx] == "σ":
text_token = text_token[:text_last_greek_idx] + "ς" + text_token[text_last_greek_idx+1:]
result.append(text_token)
return "".join(result)
def _consume_word_alignment(
aligned: List[Tuple[str, int]],
start_idx: int,
expected_syllables: List[str],
) -> Tuple[List[Tuple[str, int]], int]:
if start_idx >= len(aligned):
return [], start_idx
expected_bases = only_bases("".join(expected_syllables))
if expected_bases:
taken: List[Tuple[str, int]] = []
i = start_idx
while i < len(aligned):
taken.append(aligned[i])
current_bases = only_bases("".join(s for s, _ in taken))
if current_bases == expected_bases:
return taken, i + 1
if len(current_bases) > len(expected_bases) and not current_bases.startswith(expected_bases):
break
i += 1
fallback_count = len(expected_syllables)
if fallback_count <= 0:
return [], start_idx
end_idx = min(len(aligned), start_idx + fallback_count)
return aligned[start_idx:end_idx], end_idx
def _render_plain_line_per_word(line: str, model_id: str) -> str:
"""Render plain line by processing each word separately."""
line_for_matching = line.replace("ς", "σ")
parts = re.findall(r"\S+|\s+", line)
parts_for_matching = re.findall(r"\S+|\s+", line_for_matching)
out_parts: List[str] = []
for part, part_for_matching in zip(parts, parts_for_matching):
if part_for_matching.isspace():
# Preserve original spacing exactly.
out_parts.append(part)
continue
plain_word, _ = _render_word_plain_and_syllables(part, part_for_matching, model_id)
out_parts.append(plain_word)
return "".join(out_parts)
def _render_word_plain_and_syllables(
    part: str,
    part_for_matching: str,
    model_id: str,
) -> Tuple[str, List[Tuple[str, int]]]:
    """Classify one word and return its marked text and labelled segments.

    ``part`` is the word as typed; ``part_for_matching`` is the form sent to
    the classifier. Returns ``(plain_word, segments)`` where ``segments`` are
    ``(text, label)`` pieces of ``part`` and ``plain_word`` is those pieces
    with markers inserted. When classification or alignment yields nothing,
    the word is returned untouched as a single segment with label 0.
    """
    aligned = classify_line(part_for_matching, model_id)
    if not aligned:
        return part, [(part, 0)]

    folded_word = lower_grc(normalize_word(part_for_matching)).replace("ς", "σ")
    expected_syllables = syllabify_joined(process_word(folded_word))

    taken, _ = _consume_word_alignment(aligned, 0, expected_syllables)
    if not taken:
        return part, [(part, 0)]

    segments = _segment_reference_by_syllable(part, taken)
    marked_pieces = [_mark_syllable_plain(text, label) for text, label in segments]
    return "".join(marked_pieces), segments
def _render_styled_line_per_word(line: str, model_id: str) -> str:
line_for_matching = line.replace("ς", "σ")
parts = re.findall(r"\S+|\s+", line)
parts_for_matching = re.findall(r"\S+|\s+", line_for_matching)
out_parts: List[str] = []
for part, part_for_matching in zip(parts, parts_for_matching):
if part_for_matching.isspace():
out_parts.append(html.escape(part_for_matching))
continue
_, syllables = _render_word_plain_and_syllables(part, part_for_matching, model_id)
out_parts.extend(_style_syllable_vowels(syl, label) for syl, label in syllables)
return "".join(out_parts)
def _convert_final_sigmas(text: str) -> str:
"""Convert word-final σ to ς (final sigma) for readability."""
# Find all words (sequences of non-space characters that include Greek letters)
def convert_word(match):
word = match.group(0)
# Find the last Greek letter in the word
for i in range(len(word) - 1, -1, -1):
ch = word[i]
if "\u0370" <= ch <= "\u03ff" or "\u1f00" <= ch <= "\u1fff":
# Found the last Greek letter
if ch == "σ":
# Convert medial sigma to final sigma
return word[:i] + "ς" + word[i+1:]
break
return word
# Replace word-final σ with ς
return re.sub(r"\S+", convert_word, text)
@spaces.GPU(duration=120)
def macronize_ui(text: str, model_label: str):
    """Classify every non-blank line of ``text`` for the web UI.

    Args:
        text: Raw textbox content; blank lines are skipped and the remaining
            lines are stripped.
        model_label: Display label or alias, resolved by ``_resolve_model_id``.

    Returns:
        ``(html_result, plain_result)``: an HTML fragment with a legend and
        one card per line, and the plain-text rendering with one output line
        per input line. With no usable input the HTML is a hint message and
        the plain text is empty.
    """
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if not lines:
        return "<div class='empty'>Enter one or more Greek lines to classify syllables.</div>", ""

    model_id = _resolve_model_id(model_label)
    cards = []
    plain_lines = []
    for idx, line in enumerate(lines, start=1):
        # Build both renderings from a single classification per word, so the
        # model is not run twice on the same input.
        plain_parts = []
        styled_parts = []
        for part in re.findall(r"\S+|\s+", line):
            if part.isspace():
                plain_parts.append(part)
                styled_parts.append(html.escape(part))
                continue
            plain_word, syllables = _render_word_plain_and_syllables(
                part, part.replace("ς", "σ"), model_id
            )
            plain_parts.append(plain_word)
            styled_parts.extend(_style_syllable_vowels(syl, label) for syl, label in syllables)
        plain_line = "".join(plain_parts)
        styled_line = "".join(styled_parts)

        cards.append(
            f"""
<section class="card">
<div class="line-number">Line {idx}</div>
<div class="line-output">{styled_line or '<span class="empty-inline">(no syllables found)</span>'}</div>
</section>
"""
        )
        plain_lines.append(plain_line if plain_line else "(no syllables found)")

    html_result = (
        "<div class='legend'><span class='dot long'></span>Long"
        "<span class='dot short'></span>Short</div>"
        + "".join(cards)
    )
    return html_result, "\n".join(plain_lines)
@spaces.GPU(duration=120)
def macronize(text: str, model: Literal["roberta", "modernbert"] = "roberta") -> str:
    """
    Mark Ancient Greek alphas, iotas and ypsilons with trailing underscores (_) if long and carets (^) if short.
    Arguments:
    - `text`: One or more lines of Ancient Greek text to classify. Each line is processed separately.
    - `model`: Which model to use for classification. Options are "modernbert" (Ericu950/SyllaMoBert-grc-macronizer-v1, a pretrained ModernBERT 0.1B, slightly higher benchmark score but larger and slower) and "roberta" (Ericu950/macronizer_mini, a fine-tuned RoBERTa 11M, smaller and faster).
    Made by Albin Thörn Cleland (Lund university) and Eric Cullhed (Uppsala university).
    Training the models was made possible by resources provided by the National Academic Infrastructure for Supercomputing in Sweden (NAISS), partially funded by the Swedish Research Council (grant agreement no. 2022-06725).
    """
    # The marker wording above follows _mark_syllable_plain, which emits "_"
    # for label 1 (long) and "^" for label 2 (short).
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if not lines:
        return ""
    model_id = _resolve_model_id(model)
    return "\n".join(_render_plain_line_per_word(line, model_id) for line in lines)
# Sample inputs offered under "Try examples"; the last one spans two lines.
examples = [
    "μῆνιν ἄειδε θεὰ Πηληϊάδεω Ἀχιλῆος",
    "νεανίας, ἀάατός",
    (
        "Ἆρες, Ἄρες βροτολοιγὲ μιαιφόνε τειχεσιπλῆτα\n"
        "Ἀτρεΐδαι τε καὶ ἄλλοι ἐϋκνήμιδες Ἀχαιοί"
    ),
]
# Stylesheet passed to demo.launch(css=...). The colours used for long and
# short vowels are declared once as the custom properties --long / --short
# and reused by the legend dots (.dot) and the vowel spans (.vowel).
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Cormorant+Garamond:wght@500;600;700&family=Space+Grotesk:wght@400;500;700&display=swap');
:root {
--long: #b93525;
--short: #087b83;
--chip-long-color: var(--long);
--chip-short-color: var(--short);
}
.gradio-container {
font-family: 'Space Grotesk', sans-serif;
color: var(--body-text-color);
}
.title h1 {
font-family: 'Cormorant Garamond', serif;
font-size: 3rem;
letter-spacing: 0.02em;
margin-bottom: 0.2rem;
color: var(--body-text-color);
}
.title p {
opacity: 0.82;
color: var(--body-text-color);
}
.panel {
background: var(--block-background-fill);
border: 1px solid var(--block-border-color);
border-radius: var(--block-radius);
padding: 0.9rem;
}
#try-examples .examples {
grid-template-columns: repeat(auto-fit, minmax(min(100%, 320px), 1fr));
align-items: stretch;
max-width: 100%;
}
#try-examples .example {
min-height: 0;
height: auto;
}
#try-examples .example-content,
#try-examples .example-text-content {
width: 100%;
}
#try-examples .example-text {
display: block;
white-space: pre-wrap;
overflow: visible;
text-overflow: clip;
line-height: 1.35;
}
.legend {
display: flex;
align-items: center;
gap: 0.9rem;
font-weight: 600;
margin-bottom: 0.8rem;
}
.dot {
display: inline-block;
width: 10px;
height: 10px;
border-radius: 999px;
margin-left: 0.7rem;
margin-right: 0.25rem;
}
.dot.long { background: var(--long); }
.dot.short { background: var(--short); }
.card {
background: var(--block-background-fill);
border-radius: var(--block-radius);
padding: 0.9rem;
margin: 0.8rem 0;
border: 1px solid var(--block-border-color);
animation: rise 420ms ease both;
color: var(--body-text-color);
}
.line-number {
font-size: 0.8rem;
font-weight: 700;
text-transform: uppercase;
letter-spacing: 0.06em;
color: var(--body-text-color-subdued);
}
.line-output {
font-family: 'Cormorant Garamond', serif;
font-size: 1.5rem;
line-height: 1.45;
color: var(--body-text-color);
overflow-wrap: anywhere;
white-space: pre-wrap;
}
.vowel.long {
color: var(--chip-long-color);
font-weight: 700;
}
.vowel.short {
color: var(--chip-short-color);
font-weight: 700;
}
.empty {
padding: 1rem;
border-radius: 12px;
background: var(--block-background-fill);
border: 1px dashed var(--block-border-color);
color: var(--body-text-color);
}
@keyframes rise {
from { transform: translateY(8px); opacity: 0; }
to { transform: translateY(0); opacity: 1; }
}
@media (max-width: 820px) {
.title h1 { font-size: 2.2rem; }
}
"""
# Build the Gradio interface: intro text, model selector and input box,
# buttons and examples, the two output widgets, API usage notes, and the
# event wiring for the buttons.
# NOTE(review): the indentation of this block is missing in this copy of the
# file, so the exact nesting of the layout containers cannot be read off here.
with gr.Blocks(title="The First Ancient Greek Macronizer") as demo:
gr.Markdown(
"""
<div class="title">
<h1>The First Ancient Greek Macronizer</h1>
<p>
Enter Ancient Greek text to have the alphas, iotas and ypsilons marked as long or short. <br><br>There are two models: a pretrained ModernBERT 0.1B with the highest Norma Syllabarum Graecarum benchmark score (Ericu950/SyllaMoBert-grc-macronizer-v1) and the default, a tiny and extremely fast fine-tuned RoBERTa 11M (Ericu950/macronizer_mini), with only slightly lower benchmark score. <br><br>Made by Albin Thörn Cleland (Lund university) and Eric Cullhed (Uppsala university). Training the models was made possible by resources provided by the National Academic Infrastructure for Supercomputing in Sweden (NAISS), partially funded by the Swedish Research Council (grant agreement no. 2022-06725).
</p>
</div>
"""
)
with gr.Column():
with gr.Column(elem_classes=["panel"]):
# The radio offers the display labels of MODEL_OPTIONS; macronize_ui
# resolves the chosen label to a model id.
model_choice = gr.Radio(
label="Model",
choices=list(MODEL_OPTIONS.keys()),
value=DEFAULT_MODEL_LABEL,
)
text_input = gr.Textbox(
label="Greek Lines",
lines=8,
placeholder="Paste one or multiple lines; each line is processed separately.",
)
with gr.Row():
classify_btn = gr.Button("Classify", variant="primary")
clear_btn = gr.Button("Clear")
# Clicking an example fills text_input; the examples are not exposed as
# an API endpoint.
gr.Examples(
examples=examples,
inputs=text_input,
label="Try examples",
elem_id="try-examples",
api_name=False,
api_visibility="private",
)
with gr.Column(elem_classes=["panel"]):
html_output = gr.HTML(label="Styled Results")
text_output = gr.Textbox(label="Plain Output", lines=12, buttons=["copy"])
gr.Markdown(
"""
### Python API
```python
from gradio_client import Client
client = Client("Urdatorn/macronizer")
result = client.predict(
text="μῆνιν ἄειδε θεὰ Πηληϊάδεω Ἀχιλῆος",
model="roberta",
api_name="/macronize",
)
print(result)
```
Model options are `model="modernbert"` for the slightly higher benchmark score but larger and slower pretrained 0.1B model, or `model="roberta"` for the smaller and faster fine-tuned RoBERTa 11M model.
"""
)
# Invisible components whose only job is to give the "macronize" endpoint
# below its inputs (text, model) and its output.
api_text = gr.Textbox(label="text", value="μῆνιν ἄειδε θεὰ Πηληϊάδεω Ἀχιλῆος", visible=False)
api_model = gr.Textbox(label="model", value="roberta", visible=False)
api_output = gr.Textbox(label="plain_output", visible=False)
api_btn = gr.Button("API", visible=False)
# UI-only handlers: both are registered with api_name=False and
# api_visibility="private".
classify_btn.click(
macronize_ui,
inputs=[text_input, model_choice],
outputs=[html_output, text_output],
api_name=False,
api_visibility="private",
)
# Clear resets the input box and both outputs to empty strings.
clear_btn.click(
lambda: ("", "", ""),
outputs=[text_input, html_output, text_output],
api_name=False,
api_visibility="private",
)
# The one handler registered under a public API name ("macronize").
api_btn.click(
macronize,
inputs=[api_text, api_model],
outputs=api_output,
api_name="macronize",
)
if __name__ == "__main__":
    # Collect the launch settings first so the call itself stays short.
    launch_options = dict(
        css=CSS,
        footer_links=["settings", "api"],
        mcp_server=True,
        theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan", neutral_hue="slate"),
    )
    demo.launch(**launch_options)