# app.py — sbompolas / Morphosyntactic Parser for Modern Greek dialects
# (Hugging Face Space; last update commit 51ce54e)
import gradio as gr
from gradio import update
import stanza
import pandas as pd
import requests
import traceback
from pathlib import Path
# 1. MODEL VARIANTS & INITIALIZATION
# Filled in by initialize_models(): variant display name -> stanza.Pipeline.
LESBIAN_MODELS = {}
# Variant display name -> Hugging Face repo id hosting the four .pt model files.
MODEL_VARIANTS = {
    "Lesbian-only (UD_Greek-Lesbian)": "sbompolas/Lesbian-Greek-Morphosyntactic-Model",
    "Lesbian-augmented (UD_Greek-Lesbian+NGUD)": "sbompolas/NGUD-Lesbian-Morphosyntactic-Model",
    "Standard Modern Greek (UD_Greek-GUD)": "sbompolas/GUD",
    "Cretan-only (UD_Greek-Cretan)": "sbompolas/Cretan"
}
def download_model_file(url, filename, timeout=60):
    """Stream the file at *url* to *filename*.

    Args:
        url: Direct download URL (a Hugging Face `resolve/main/...` link).
        filename: Local path to write the file to.
        timeout: Seconds before the HTTP request gives up; the original call
            had no timeout and could hang indefinitely on a stalled server.

    Returns:
        True on success, False on any failure (logged to stdout).
    """
    try:
        resp = requests.get(url, stream=True, timeout=timeout)
        resp.raise_for_status()
        with open(filename, "wb") as f:
            # Stream in 8 KiB chunks so large model files are never fully
            # buffered in memory.
            for chunk in resp.iter_content(8192):
                f.write(chunk)
        return True
    except Exception as e:
        # Name the URL that failed; the old message said "(unknown)".
        print(f"Download failed for {url}: {e}")
        return False
def initialize_models():
    """Fetch every model variant and build a stanza Pipeline for each.

    Downloads the four processor files (tokenizer, lemmatizer, POS tagger,
    dependency parser) for each entry in MODEL_VARIANTS into ./models/<name>/,
    skipping files already on disk, then registers the resulting pipeline in
    the module-level LESBIAN_MODELS dict.

    Returns:
        (True, "Models loaded") on success, or (False, <reason>) on the first
        failed download or any unexpected exception.
    """
    try:
        model_root = Path("./models")
        model_root.mkdir(exist_ok=True)
        component_files = ("tokenizer.pt", "lemmatizer.pt", "pos.pt", "depparse.pt")
        for variant, repo in MODEL_VARIANTS.items():
            variant_dir = model_root / variant
            variant_dir.mkdir(exist_ok=True)
            for component in component_files:
                target = variant_dir / component
                if target.exists():
                    continue
                url = f"https://huggingface.co/{repo}/resolve/main/{component}"
                if not download_model_file(url, str(target)):
                    return False, f"Failed to download {component} for {variant}"
            LESBIAN_MODELS[variant] = stanza.Pipeline(
                processors='tokenize,pos,lemma,depparse',
                lang='el',
                use_gpu=False,
                verbose=False,
                tokenize_model_path=str(variant_dir / "tokenizer.pt"),
                pos_model_path=str(variant_dir / "pos.pt"),
                lemma_model_path=str(variant_dir / "lemmatizer.pt"),
                depparse_model_path=str(variant_dir / "depparse.pt"),
            )
        return True, "Models loaded"
    except Exception as exc:
        traceback.print_exc()
        return False, str(exc)
# Load all pipelines eagerly at import time; the UI reads these two flags.
loaded, load_status = initialize_models()
# 2. CONLL-U / DATAFRAME
def stanza_doc_to_conllu(doc) -> str:
    """Serialize a Stanza document as CoNLL-U text.

    Each sentence is emitted with `# sent_id` / `# text` comments followed by
    one 10-column tab-separated row per word and a trailing blank line.
    DEPS and MISC are always "_"; missing values map to "_" (HEAD to "0").
    """
    out = []
    for sent_num, sentence in enumerate(doc.sentences, start=1):
        out.append(f"# sent_id = {sent_num}")
        out.append(f"# text = {sentence.text}")
        for word in sentence.words:
            head = "0" if word.head is None else str(word.head)
            row = [
                str(word.id),
                word.text,
                word.lemma or "_",
                word.upos or "_",
                word.xpos or "_",
                word.feats or "_",
                head,
                word.deprel or "_",
                "_",
                "_",
            ]
            out.append("\t".join(row))
        # Sentence separator required by the CoNLL-U format.
        out.append("")
    return "\n".join(out)
def conllu_to_dataframe(conllu: str) -> pd.DataFrame:
    """Parse CoNLL-U text into a display-oriented DataFrame.

    Comment lines (`# sent_id = ...`, `# text = ...`) become rows whose ID
    column holds the whole comment; blank separator lines become all-empty
    rows; token lines become regular 10-column rows. Unknown comment lines
    (e.g. `# newpar`, which has no "=") are skipped — the original code
    raised ValueError on them.

    Returns:
        DataFrame with the standard CoNLL-U columns, NaN-free (empty strings).
    """
    cols = ['ID','FORM','LEMMA','UPOS','XPOS','FEATS','HEAD','DEPREL','DEPS','MISC']
    rows = []
    for line in conllu.splitlines():
        if not line:
            # empty separator row between sentences
            rows.append({c: "" for c in cols})
            continue
        if line.startswith("#"):
            if "=" not in line:
                # Comment without a key/value payload — nothing to display.
                continue
            key, val = line.lstrip("#").split("=", 1)
            key, val = key.strip(), val.strip()
            if key == "sent_id":
                rows.append({'ID': f"# sent_id = {val}", 'FORM': ""})
            elif key == "text":
                rows.append({'ID': f"# text = {val}", 'FORM': ""})
            continue
        parts = line.split("\t")
        if len(parts) >= 10:
            rows.append(dict(zip(cols, parts)))
    # Comment rows only set ID/FORM; fillna blanks the remaining columns.
    return pd.DataFrame(rows, columns=cols).fillna("")
# 3. FULL SVG BUILDER (crop top + bottom padding, arrows at start)
def create_single_sentence_svg(sentence_data, sentence_num=1, total_sentences=1):
    """Render one sentence's dependency parse as a self-contained SVG string.

    Args:
        sentence_data: list of per-token dicts (keys ID/FORM/LEMMA/UPOS/
            FEATS/HEAD/DEPREL) or an equivalent DataFrame.
        sentence_num, total_sentences: accepted but currently unused.

    Returns:
        SVG markup, or an HTML ``<p>`` error message if rendering fails
        (NOTE: an empty sentence divides by zero in the spacing formula and
        is therefore reported through the same error path).
    """
    try:
        df = pd.DataFrame(sentence_data) if isinstance(sentence_data, list) else sentence_data
        n = len(df)
        # Horizontal layout: at least base_w px per token, min_sp px of gap.
        base_w, min_sp = 100, 30
        spacing = max(base_w, (n*base_w + (n-1)*min_sp)/n)
        width = max(800, n*spacing + 100)
        # Vertical layout: viewBox crops unused headroom off the top and
        # padding is added below the feature annotations.
        orig_height = 500
        crop_top = 30 # px to remove from top
        bottom_pad = 30 # px to add at bottom
        height = orig_height - crop_top + bottom_pad
        word_y = height - 120   # baseline of token text
        feats_y = word_y + 35   # first annotation line under each token
        # One color per dependency relation; unknown relations fall back to black.
        colors = {
            'root':'#000000','nsubj':'#2980b9','obj':'#27ae60','det':'#e67e22',
            'amod':'#8e44ad','nmod':'#16a085','case':'#34495e','punct':'#7f8c8d',
            'cc':'#d35400','conj':'#2c3e50','cop':'#e74c3c','mark':'#9b59b6',
            'csubj':'#3498db','xcomp':'#1abc9c','ccomp':'#f39c12','advcl':'#e91e63',
            'advmod':'#9c27b0','obl':'#795548','iobj':'#607d8b','fixed':'#ff5722',
            'aux':'#ff9800','acl':'#4caf50','appos':'#673ab7','compound':'#009688'
        }
        svg = [
            f'<svg width="{width}" height="{height}" viewBox="0 {crop_top} {width} {orig_height}" '
            'xmlns="http://www.w3.org/2000/svg" style="background:white;border:1px solid #eee;"><defs>'
        ]
        # One arrowhead marker per relation color; orient="auto-start-reverse"
        # lets the same marker be used at the start of the arc path.
        for rel, c in colors.items():
            svg.append(
                f'<marker id="m_{rel}" markerWidth="4" markerHeight="4" '
                'markerUnits="userSpaceOnUse" orient="auto-start-reverse" refX="3.5" refY="2">'
                f'<path d="M0,0 L4,2 L0,4Z" fill="{c}"/></marker>'
            )
        svg.append('</defs><g>')
        # compute x positions (token id -> x center), skipping comment rows
        # whose ID column is not a plain integer
        xpos = {
            int(r['ID']): 50 + (int(r['ID']) - 1) * spacing
            for _, r in df.iterrows() if str(r['ID']).isdigit()
        }
        # (span, level) pairs of arcs already placed; used to stack
        # overlapping arcs at increasing heights
        used_spans = []
        for _, r in df.iterrows():
            if not str(r['ID']).isdigit():
                continue
            i, h = int(r['ID']), int(r['HEAD'])
            rel, c = r['DEPREL'], colors.get(r['DEPREL'], '#000')
            x1 = xpos[i]
            if h == 0:
                # ROOT line: vertical stroke from the word up to y=50 with a
                # boxed "ROOT" label at its midpoint
                svg.append(
                    f'<line x1="{x1}" y1="{word_y-15}" x2="{x1}" y2="50" '
                    f'stroke="{c}" stroke-width="1.5"/>'
                )
                mid = (word_y-15 + 50) / 2
                svg.append(
                    f'<rect x="{x1-15}" y="{mid-8}" width="30" height="14" '
                    f'fill="white" stroke="{c}" rx="2"/>'
                )
                svg.append(
                    f'<text x="{x1}" y="{mid+2}" text-anchor="middle" '
                    f'fill="{c}" font-family="Arial" font-size="8" font-weight="bold">ROOT</text>'
                )
            else:
                # Dependency arc from head to dependent. If the head id is
                # missing from xpos, the arc degenerates to a point at x1.
                x2 = xpos.get(h, x1)
                span = (min(i, h), max(i, h))
                # Find the lowest stacking level where this arc's token span
                # does not overlap any already-placed arc at the same level.
                lvl = 0
                conflict = True
                while conflict:
                    conflict = False
                    for (es, el), used_lvl in used_spans:
                        if used_lvl == lvl and not (span[1] < es or span[0] > el):
                            lvl += 1
                            conflict = True
                            break
                used_spans.append((span, lvl))
                # Arc height grows with horizontal distance (capped at 100)
                # plus 35 px per stacking level.
                dist = abs(x2 - x1)
                arc_h = min(40 + dist * 0.15, 100) + lvl * 35
                midx, cty = (x1 + x2) / 2, word_y - arc_h
                # Quadratic Bézier with the arrow marker at the START (head end).
                path_d = f'M{x1} {word_y-15} Q{midx} {cty} {x2} {word_y-15}'
                svg.append(
                    f'<path d="{path_d}" stroke="{c}" fill="none" stroke-width="1.5" '
                    f'marker-start="url(#m_{rel})"/>'
                )
                # Label position = Bézier point at t=0.5:
                # B(0.5) = 0.25*P0 + 0.5*P1 + 0.25*P2.
                amx = 0.25*x1 + 0.5*midx + 0.25*x2
                amy = 0.25*(word_y-15) + 0.5*cty + 0.25*(word_y-15)
                # White background box sized roughly to the relation name.
                lw = len(rel)*6 + 8
                svg.append(
                    f'<rect x="{amx-lw/2}" y="{amy-8}" width="{lw}" height="14" '
                    f'fill="white" stroke="{c}" rx="2"/>'
                )
                svg.append(
                    f'<text x="{amx}" y="{amy+2}" text-anchor="middle" '
                    f'fill="{c}" font-family="Arial" font-size="8" font-weight="bold">{rel}</text>'
                )
        # draw tokens + annotations
        for _, r in df.iterrows():
            if not str(r['ID']).isdigit():
                continue
            x = xpos[int(r['ID'])]
            svg.append(
                f'<text x="{x}" y="{word_y}" text-anchor="middle" '
                f'font-family="Arial" font-size="13" font-weight="bold">{r["FORM"]}</text>'
            )
            # Stack upos, lemma (only when it differs from the form) and each
            # individual morphological feature under the token.
            ann = []
            if r['UPOS'] and r['UPOS'] != '_': ann.append(f"upos={r['UPOS']}")
            if r['LEMMA'] not in ('_', r['FORM']): ann.append(f"lemma={r['LEMMA']}")
            if r['FEATS'] and r['FEATS'] not in ('', '_'):
                for f in r['FEATS'].split('|'):
                    if '=' in f:
                        ann.append(f)
            for i, a in enumerate(ann):
                svg.append(
                    f'<text x="{x}" y="{feats_y + i*12}" text-anchor="middle" '
                    f'font-family="Arial" font-size="7" fill="#666">{a}</text>'
                )
        svg.append('</g></svg>')
        return "".join(svg)
    except Exception as e:
        return f"<p>Error creating SVG: {e}</p>"
# 4. PROCESS & DROPDOWN-UPDATES
def process_text(text, variant):
    """Parse *text* with the selected model variant and build every output.

    Args:
        text: Raw input text to analyze.
        variant: Display name key into LESBIAN_MODELS.

    Returns:
        5-tuple matching the Gradio outputs: (SVG html or update, sentence
        dropdown update, per-sentence token payloads for gr.State, CoNLL-U
        string, token DataFrame).
    """
    # Use the imported gradio `update` helper throughout (already used for
    # dd_upd below); the component-level gr.HTML.update / gr.Dropdown.update
    # helpers were removed in Gradio 4.x and raised AttributeError there.
    if not text.strip():
        return (
            update(value="<p>No data</p>"),
            update(choices=[], value=None),
            [], "", pd.DataFrame()
        )
    pipe = LESBIAN_MODELS.get(variant)
    if pipe is None:
        return (
            update(value="<p>Error: model not loaded</p>"),
            update(choices=[], value=None),
            [], "", pd.DataFrame()
        )
    doc = pipe(text)
    conllu = stanza_doc_to_conllu(doc)
    df = conllu_to_dataframe(conllu)
    # Flatten each sentence into the plain-dict rows the SVG builder expects.
    sentences = []
    for sent in doc.sentences:
        payload = [{
            'ID': w.id,
            'FORM': w.text,
            'LEMMA': w.lemma or "_",
            'UPOS': w.upos or "_",
            'XPOS': w.xpos or "_",
            'FEATS': w.feats or "_",
            'HEAD': w.head or 0,
            'DEPREL': w.deprel or "_"
        } for w in sent.words]
        sentences.append(payload)
    # Dropdown choices are 1-based sentence numbers as strings.
    sent_ids = [str(i + 1) for i in range(len(sentences))]
    dd_upd = update(choices=sent_ids, value=sent_ids[0] if sent_ids else None)
    init_svg = create_single_sentence_svg(sentences[0]) if sentences else "<p>No data</p>"
    return init_svg, dd_upd, sentences, conllu, df
def update_svg(selected_id, sentences):
try:
idx = int(selected_id) - 1
return create_single_sentence_svg(sentences[idx])
except:
return "<p>Invalid selection</p>"
# 5. BUILD GRADIO UI
def create_app():
    """Build and return the Gradio Blocks UI.

    Wires the Parse button to process_text and the sentence dropdown to
    update_svg; shows the model-load error inline when startup failed.
    """
    with gr.Blocks(title="Parser for MG Dialects") as app:
        gr.Markdown("# Morphosyntactic Parser for MG Dialects")
        if not loaded:
            gr.Markdown(f"❌ Load error: {load_status}")
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(label="Input Text", lines=4, placeholder="Εισάγετε κείμενο…")
                # Default to the first variant: the previous hard-coded value
                # "Lesbian-only" was not a member of `choices` (the keys are
                # the full names like "Lesbian-only (UD_Greek-Lesbian)").
                variant_names = list(MODEL_VARIANTS.keys())
                mdl = gr.Radio(choices=variant_names, value=variant_names[0], label="Model Variant")
                btn = gr.Button("Parse", variant="primary")
        with gr.Row():
            with gr.Column():
                svg_out = gr.HTML("<p>No visualization</p>")
                sentence_dd = gr.Dropdown(label="Choose sentence", choices=[])
                # Per-session storage of the parsed sentences for update_svg.
                sentences_st = gr.State([])
        with gr.Row():
            with gr.Column():
                conllu_out = gr.Textbox(label="CoNLL-U", lines=10, show_copy_button=True)
                table_out = gr.Dataframe(label="Token Table")
        btn.click(
            fn=process_text,
            inputs=[txt, mdl],
            outputs=[svg_out, sentence_dd, sentences_st, conllu_out, table_out]
        )
        sentence_dd.change(fn=update_svg, inputs=[sentence_dd, sentences_st], outputs=svg_out)
    return app
# Launch the UI only when executed as a script (not when imported).
if __name__ == "__main__":
    create_app().launch()