unijoh committed on
Commit
a583cec
·
verified ·
1 Parent(s): e334dd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +285 -58
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
2
  import re
3
  import string
 
 
4
 
5
  import gradio as gr
6
  import torch
@@ -12,27 +14,28 @@ from transformers import AutoTokenizer, AutoModelForTokenClassification
12
  # Config
13
  # ----------------------------
14
  MODEL_ID = "Setur/BRAGD"
15
- TAGS_FILEPATH = "Sosialurin-BRAGD_tags.csv" # must be present in the Space repo
16
- HF_TOKEN = os.getenv("BRAGD") # Space secret name
17
 
 
18
  if not HF_TOKEN:
19
  raise RuntimeError("Missing BRAGD token secret (Space → Settings → Secrets → BRAGD).")
20
 
21
- # Match UPDATED demo.py intervals
22
  INTERVALS = (
23
- (15, 29), # Subcategories (D,B,E,I,P,Q,N,G,R,X,S,C,O,T,s)
24
- (30, 33), # Gender (M,F,N,g)
25
- (34, 36), # Number (S,P,n)
26
- (37, 41), # Case (N,A,D,G,c)
27
- (42, 43), # Article/No-Article (Article,a)
28
- (44, 45), # Proper/Not Proper Noun (Proper,r)
29
- (46, 50), # Degree (P,C,S,A,d)
30
- (51, 53), # Declension (S,W,e)
31
- (54, 60), # Mood (I,M,N,S,P,E,U)
32
- (61, 63), # Voice (A,M,v)
33
- (64, 66), # Tense (P,A,t)
34
- (67, 70), # Person (1,2,3,p)
35
- (71, 72), # Definiteness (D,I)
36
  )
37
 
38
  # ----------------------------
@@ -46,21 +49,29 @@ model.to(device)
46
  model.eval()
47
 
48
  # ----------------------------
49
- # Tag mapping + dict_intervals
50
  # ----------------------------
51
  def load_tag_mappings(tags_filepath: str):
52
  tags_df = pd.read_csv(tags_filepath)
53
 
54
- # Map: Original Tag -> feature vector, and feature vector -> Original Tag
55
- tag_to_features = {row["Original Tag"]: row[1:].values.astype(int) for _, row in tags_df.iterrows()}
56
- features_to_tag = {tuple(row[1:].values.astype(int)): row["Original Tag"] for _, row in tags_df.iterrows()}
 
 
 
 
 
 
 
57
 
58
- vec_len = len(tags_df.columns) - 1
59
- return tag_to_features, features_to_tag, vec_len
60
 
61
- tag_to_features, features_to_tag, VEC_LEN = load_tag_mappings(TAGS_FILEPATH)
62
 
63
- # Safety check: if this fails, you uploaded the wrong CSV for the model
 
 
64
  if hasattr(model, "config") and hasattr(model.config, "num_labels"):
65
  if model.config.num_labels != VEC_LEN:
66
  raise RuntimeError(
@@ -69,12 +80,17 @@ if hasattr(model, "config") and hasattr(model.config, "num_labels"):
69
  "You likely uploaded the wrong tag mapping CSV."
70
  )
71
 
 
 
 
 
 
 
 
72
  def process_tag_features(tag_to_features: dict, intervals):
73
- """Compute allowed intervals per POS (dict_intervals) like your updated demo.py."""
74
  list_of_tags = list(tag_to_features.values())
75
  unique_arrays = [np.array(tpl) for tpl in set(tuple(arr) for arr in list_of_tags)]
76
 
77
- # Collect all feature vectors for each POS class (0..14)
78
  word_type_masks = {}
79
  for wt in range(15):
80
  word_type_masks[wt] = [arr for arr in unique_arrays if arr[wt] == 1]
@@ -97,27 +113,100 @@ def process_tag_features(tag_to_features: dict, intervals):
97
 
98
  return dict_intervals
99
 
 
100
  DICT_INTERVALS = process_tag_features(tag_to_features, INTERVALS)
101
 
102
- def vector_to_tag(vec: torch.Tensor) -> str:
103
- return features_to_tag.get(tuple(vec.int().tolist()), "Unknown Tag")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  # ----------------------------
106
- # Tokenization (match updated demo.py)
107
  # ----------------------------
108
  def simp_tok(sentence: str):
109
- """Tokenize into words and punctuation (regex), matching your updated demo.py."""
110
  return re.findall(r"\w+|[" + re.escape(string.punctuation) + "]", sentence)
111
 
112
  # ----------------------------
113
- # Decoding (match updated demo.py logic)
114
  # ----------------------------
115
  def predict_vectors(logits: torch.Tensor, attention_mask: torch.Tensor, begin_tokens, dict_intervals, vec_len: int):
116
- """
117
- Decode one feature-vector per word:
118
- - pick POS (0..14)
119
- - then pick subclasses only in allowed intervals for that POS
120
- """
121
  softmax = torch.nn.Softmax(dim=0)
122
  vectors = []
123
 
@@ -135,7 +224,7 @@ def predict_vectors(logits: torch.Tensor, attention_mask: torch.Tensor, begin_to
135
  wt = torch.argmax(probs).item()
136
  vec[wt] = 1
137
 
138
- # Allowed feature groups for this POS
139
  for (a, b) in dict_intervals.get(wt, []):
140
  seg = pred_logits[a : b + 1]
141
  probs = softmax(seg)
@@ -146,14 +235,95 @@ def predict_vectors(logits: torch.Tensor, attention_mask: torch.Tensor, begin_to
146
 
147
  return vectors
148
 
149
- def tag_sentence(sentence: str, max_len: int = 128):
150
- sentence = sentence.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  if not sentence:
152
- return ""
153
 
154
  tokens = simp_tok(sentence)
155
  if not tokens:
156
- return ""
157
 
158
  enc = tokenizer(
159
  tokens,
@@ -170,7 +340,6 @@ def tag_sentence(sentence: str, max_len: int = 128):
170
  attention_mask = enc["attention_mask"].to(device)
171
  word_ids = enc.word_ids(batch_index=0)
172
 
173
- # begin token mask: first subtoken per word
174
  begin_tokens = []
175
  last = None
176
  for wid in word_ids:
@@ -184,12 +353,11 @@ def tag_sentence(sentence: str, max_len: int = 128):
184
 
185
  with torch.no_grad():
186
  out = model(input_ids=input_ids, attention_mask=attention_mask)
187
- logits = out.logits[0] # [seq_len, num_labels]
188
 
189
  vectors = predict_vectors(logits, attention_mask[0], begin_tokens, DICT_INTERVALS, VEC_LEN)
190
 
191
- # Map vectors back to tokens (one vector per original word)
192
- lines = []
193
  vec_i = 0
194
  seen_word_ids = set()
195
 
@@ -203,25 +371,84 @@ def tag_sentence(sentence: str, max_len: int = 128):
203
 
204
  seen_word_ids.add(wid)
205
  word = tokens[wid] if wid < len(tokens) else "<UNK>"
206
- tag = vector_to_tag(vectors[vec_i]) if vec_i < len(vectors) else "Unknown Tag"
207
- lines.append(f"{word}\t{tag}")
 
 
 
 
208
  vec_i += 1
209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  return "\n".join(lines)
211
 
 
212
  # ----------------------------
213
  # Gradio UI
214
  # ----------------------------
215
- demo = gr.Interface(
216
- fn=tag_sentence,
217
- inputs=gr.Textbox(lines=2, label="Setningur"),
218
- outputs=gr.Textbox(lines=12, label="Orð\\tMark"),
219
- title="BRAGD-markarin",
220
- description=(
221
- "Skriv ein setning og fá hann markaðan. "
222
- "Model: Setur/BRAGD. "
223
- ),
224
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
  if __name__ == "__main__":
227
- demo.launch()
 
1
  import os
2
  import re
3
  import string
4
+ import json
5
+ from collections import defaultdict
6
 
7
  import gradio as gr
8
  import torch
 
14
  # Config
15
  # ----------------------------
16
  MODEL_ID = "Setur/BRAGD"
17
+ TAGS_FILEPATH = "Sosialurin-BRAGD_tags.csv" # must be in the Space repo
18
+ LABELS_FILEPATH = "tag_labels.json" # add this file to the Space repo
19
 
20
+ HF_TOKEN = os.getenv("BRAGD") # Space secret name
21
  if not HF_TOKEN:
22
  raise RuntimeError("Missing BRAGD token secret (Space → Settings → Secrets → BRAGD).")
23
 
24
+ # Match your UPDATED demo.py intervals
25
  INTERVALS = (
26
+ (15, 29), # Subcategories
27
+ (30, 33), # Gender
28
+ (34, 36), # Number
29
+ (37, 41), # Case
30
+ (42, 43), # Article/No-Article
31
+ (44, 45), # Proper/Not Proper
32
+ (46, 50), # Degree
33
+ (51, 53), # Declension
34
+ (54, 60), # Mood
35
+ (61, 63), # Voice
36
+ (64, 66), # Tense
37
+ (67, 70), # Person
38
+ (71, 72), # Definiteness
39
  )
40
 
41
  # ----------------------------
 
49
  model.eval()
50
 
51
  # ----------------------------
52
+ # Tag mapping (CSV)
53
  # ----------------------------
54
  def load_tag_mappings(tags_filepath: str):
55
  tags_df = pd.read_csv(tags_filepath)
56
 
57
+ feature_cols = list(tags_df.columns[1:])
58
+
59
+ tag_to_features = {
60
+ row["Original Tag"]: row[1:].values.astype(int)
61
+ for _, row in tags_df.iterrows()
62
+ }
63
+ features_to_tag = {
64
+ tuple(row[1:].values.astype(int)): row["Original Tag"]
65
+ for _, row in tags_df.iterrows()
66
+ }
67
 
68
+ vec_len = len(feature_cols)
69
+ return tag_to_features, features_to_tag, vec_len, feature_cols
70
 
 
71
 
72
+ tag_to_features, features_to_tag, VEC_LEN, FEATURE_COLS = load_tag_mappings(TAGS_FILEPATH)
73
+
74
+ # Safety check
75
  if hasattr(model, "config") and hasattr(model.config, "num_labels"):
76
  if model.config.num_labels != VEC_LEN:
77
  raise RuntimeError(
 
80
  "You likely uploaded the wrong tag mapping CSV."
81
  )
82
 
83
+
84
def vector_to_tag(vec: torch.Tensor) -> str:
    """Return the original tag whose feature vector equals ``vec``.

    The tensor is cast to integers and looked up as a tuple in
    ``features_to_tag``; "Unknown Tag" is returned when there is no match.
    """
    key = tuple(vec.int().tolist())
    try:
        return features_to_tag[key]
    except KeyError:
        return "Unknown Tag"
86
+
87
+ # ----------------------------
88
+ # Compute allowed intervals per POS
89
+ # ----------------------------
90
  def process_tag_features(tag_to_features: dict, intervals):
 
91
  list_of_tags = list(tag_to_features.values())
92
  unique_arrays = [np.array(tpl) for tpl in set(tuple(arr) for arr in list_of_tags)]
93
 
 
94
  word_type_masks = {}
95
  for wt in range(15):
96
  word_type_masks[wt] = [arr for arr in unique_arrays if arr[wt] == 1]
 
113
 
114
  return dict_intervals
115
 
116
+
117
  DICT_INTERVALS = process_tag_features(tag_to_features, INTERVALS)
118
 
119
+ # ----------------------------
120
+ # Load bilingual labels
121
+ # ----------------------------
122
def load_labels(path: str):
    """Read the UTF-8 JSON file at ``path`` and return the decoded object."""
    with open(path, encoding="utf-8") as handle:
        decoded = json.loads(handle.read())
    return decoded
125
+
126
+
127
# Best-effort load: the app keeps running without label texts, but the reason
# is reported instead of being swallowed silently.
try:
    LABELS = load_labels(LABELS_FILEPATH)
    # Later lookups call .get() on this object, so anything but a JSON object
    # is treated like an unreadable file.
    if not isinstance(LABELS, dict):
        raise ValueError("top-level JSON value is not an object")
except (OSError, ValueError) as exc:
    # OSError: missing/unreadable file. ValueError: invalid JSON or encoding
    # (json.JSONDecodeError and UnicodeDecodeError are both subclasses).
    print(
        f"Warning: could not load labels from {LABELS_FILEPATH!r} ({exc}); "
        "continuing without label texts.",
        flush=True,
    )
    LABELS = {"fo": {"global": {}, "by_wc": {}}, "en": {"global": {}, "by_wc": {}}}
131
+
132
+
133
def label_for(lang: str, group: str, wc_code: str, code: str) -> str:
    """Look up the label text for ``code`` in feature group ``group``.

    A label specific to the word class ``wc_code`` takes precedence over the
    global one. Any language other than "fo" or "en" is treated as "fo".
    Returns "" when no label is found.
    """
    if lang not in ("fo", "en"):
        lang = "fo"

    tables = LABELS.get(lang, {})
    per_word_class = tables.get("by_wc", {})
    shared = tables.get("global", {})

    # Word-class-specific label first.
    if wc_code and group in per_word_class:
        for_group = per_word_class[group]
        if wc_code in for_group:
            entries = for_group[wc_code]
            if code in entries:
                return entries[code]

    # Then the label shared by all word classes.
    if group in shared:
        entries = shared[group]
        if code in entries:
            return entries[code]

    return ""
147
+
148
+ # ----------------------------
149
+ # Feature column groups (from CSV headers)
150
+ # ----------------------------
151
+ def _group_from_colname(col: str):
152
+ if col == "Article":
153
+ return ("article", "A")
154
+ if col == "Proper Noun":
155
+ return ("proper", "P")
156
+ if col.startswith("Not-Proper-Noun "):
157
+ return ("proper", col.split()[-1]) # usually r
158
+ if col.startswith("No-Article "):
159
+ return ("article", col.split()[-1]) # usually a
160
+
161
+ prefixes = [
162
+ ("Word Class ", "word_class"),
163
+ ("Subcategory ", "subcategory"),
164
+ ("No-Subcategory ", "subcategory"),
165
+ ("Gender ", "gender"),
166
+ ("No-Gender ", "gender"),
167
+ ("Number ", "number"),
168
+ ("No-Number ", "number"),
169
+ ("Case ", "case"),
170
+ ("No-Case ", "case"),
171
+ ("Degree ", "degree"),
172
+ ("No-Degree ", "degree"),
173
+ ("Declension ", "declension"),
174
+ ("No-Declension ", "declension"),
175
+ ("Mood ", "mood"),
176
+ ("Voice ", "voice"),
177
+ ("No-Voice ", "voice"),
178
+ ("Tense ", "tense"),
179
+ ("No-Tense ", "tense"),
180
+ ("Person ", "person"),
181
+ ("No-Person ", "person"),
182
+ ("Definite ", "definiteness"),
183
+ ("Indefinite ", "definiteness"),
184
+ ]
185
+
186
+ for p, g in prefixes:
187
+ if col.startswith(p):
188
+ code = col.split()[-1]
189
+ return (g, code)
190
+
191
+ return (None, None)
192
+
193
+
194
# group name -> list of (feature index, code), in CSV column order.
GROUPS = defaultdict(list)
for col_idx, col_name in enumerate(FEATURE_COLS):
    group_name, group_code = _group_from_colname(col_name)
    if not group_name:
        continue  # header belongs to no known group
    GROUPS[group_name].append((col_idx, group_code))
199
 
200
  # ----------------------------
201
+ # Tokenization
202
  # ----------------------------
203
  def simp_tok(sentence: str):
 
204
  return re.findall(r"\w+|[" + re.escape(string.punctuation) + "]", sentence)
205
 
206
  # ----------------------------
207
+ # Decoding
208
  # ----------------------------
209
  def predict_vectors(logits: torch.Tensor, attention_mask: torch.Tensor, begin_tokens, dict_intervals, vec_len: int):
 
 
 
 
 
210
  softmax = torch.nn.Softmax(dim=0)
211
  vectors = []
212
 
 
224
  wt = torch.argmax(probs).item()
225
  vec[wt] = 1
226
 
227
+ # Allowed feature groups
228
  for (a, b) in dict_intervals.get(wt, []):
229
  seg = pred_logits[a : b + 1]
230
  probs = softmax(seg)
 
235
 
236
  return vectors
237
 
238
+
239
# Order in which feature groups are listed after the word class.
_GROUP_ORDER = (
    "subcategory",
    "gender",
    "number",
    "case",
    "article",
    "proper",
    "degree",
    "declension",
    "mood",
    "voice",
    "tense",
    "person",
    "definiteness",
)

# Built-in label texts used when label_for() has no entry for a code, so the
# description stays meaningful even if labels are missing.
_FALLBACK_LABELS = {
    "en": {
        "definiteness": {"D": "definite", "I": "indefinite"},
        "article": {"A": "with suffixed definite article", "a": "no definite suffix"},
        "proper": {"P": "proper noun", "r": "not proper noun"},
        "gender": {"g": "no gender"},
        "number": {"n": "no number"},
        "case": {"c": "no case"},
        "degree": {"d": "no degree"},
        "declension": {"e": "no declension"},
        "voice": {"v": "no voice"},
        "tense": {"t": "no tense"},
        "person": {"p": "no person"},
        "subcategory": {"s": "no subcategory"},
    },
    "fo": {
        "definiteness": {"D": "bundið", "I": "óbundið"},
        "article": {"A": "við bundnum eftirlið", "a": "uttan bundið eftirlið"},
        "proper": {"P": "sernavn", "r": "ikki sernavn"},
        "gender": {"g": "einki kyn"},
        "number": {"n": "einki tal"},
        "case": {"c": "einki fall"},
        "degree": {"d": "einki stig"},
        "declension": {"e": "eingin bending"},
        "voice": {"v": "eingin søgn"},
        "tense": {"t": "eingin tíð"},
        "person": {"p": "eingin persónur"},
        "subcategory": {"s": "eingin undirflokkur"},
    },
}


def _active_code(vec, group):
    """Return the code of the first feature of ``group`` set to 1 in ``vec``, else ""."""
    for idx, code in GROUPS.get(group, []):
        if int(vec[idx].item()) == 1:
            return code
    return ""


def describe_vector(vec: torch.Tensor, lang: str) -> str:
    """Describe a feature vector as readable text.

    Args:
        vec: Feature vector indexed like the columns recorded in ``GROUPS``.
        lang: "en" for English fallback texts; any other value uses the
            Faroese ones.

    Returns:
        The active codes joined by "; ": the word class first, then the other
        groups in a fixed order. Each entry is "code – label", or just the
        code when no label text is available. Groups with no active feature
        are omitted.
    """
    wc_code = _active_code(vec, "word_class")

    parts = []
    if wc_code:
        wc_label = label_for(lang, "word_class", wc_code, wc_code)
        parts.append(f"{wc_code} – {wc_label}" if wc_label else wc_code)

    # Picked once per call instead of rebuilding the tables for every group.
    fallback = _FALLBACK_LABELS["en" if lang == "en" else "fo"]

    for g in _GROUP_ORDER:
        chosen = _active_code(vec, g)
        if not chosen:
            continue

        # Labels from label_for() win; the built-in text is only a fallback.
        lbl = label_for(lang, g, wc_code, chosen) or fallback.get(g, {}).get(chosen, "")
        parts.append(f"{chosen} – {lbl}" if lbl else chosen)

    return "; ".join(parts)
317
+
318
+
319
+ def tag_sentence(sentence: str, lang: str = "fo", max_len: int = 128):
320
+ sentence = (sentence or "").strip()
321
  if not sentence:
322
+ return pd.DataFrame(columns=["Word", "Tag", "Meaning"]), ""
323
 
324
  tokens = simp_tok(sentence)
325
  if not tokens:
326
+ return pd.DataFrame(columns=["Word", "Tag", "Meaning"]), ""
327
 
328
  enc = tokenizer(
329
  tokens,
 
340
  attention_mask = enc["attention_mask"].to(device)
341
  word_ids = enc.word_ids(batch_index=0)
342
 
 
343
  begin_tokens = []
344
  last = None
345
  for wid in word_ids:
 
353
 
354
  with torch.no_grad():
355
  out = model(input_ids=input_ids, attention_mask=attention_mask)
356
+ logits = out.logits[0]
357
 
358
  vectors = predict_vectors(logits, attention_mask[0], begin_tokens, DICT_INTERVALS, VEC_LEN)
359
 
360
+ rows = []
 
361
  vec_i = 0
362
  seen_word_ids = set()
363
 
 
371
 
372
  seen_word_ids.add(wid)
373
  word = tokens[wid] if wid < len(tokens) else "<UNK>"
374
+
375
+ vec = vectors[vec_i] if vec_i < len(vectors) else torch.zeros(VEC_LEN, device=device)
376
+ tag = vector_to_tag(vec)
377
+ meaning = describe_vector(vec, lang)
378
+
379
+ rows.append([word, tag, meaning])
380
  vec_i += 1
381
 
382
+ df = pd.DataFrame(rows, columns=["Word", "Tag", "Meaning"])
383
+ tsv = "\n".join([f"{w}\t{t}\t{m}" for w, t, m in rows])
384
+ return df, tsv
385
+
386
+
387
def build_legend(lang: str):
    """Build the Markdown legend of word-class codes for ``lang``.

    Any language other than "fo" or "en" is treated as "fo". The word classes
    come from ``LABELS[lang]["global"]["word_class"]`` and are listed sorted
    by code; if that table is empty or absent, a note about the missing label
    file is shown instead.
    """
    if lang not in ("fo", "en"):
        lang = "fo"

    # (title, hint, word-class heading, missing-labels note) per language.
    texts = {
        "en": (
            "### Legend (what the codes mean)",
            "- Tip: hover/copy from the TSV box if you want to paste into spreadsheets or docs.",
            "#### Word classes",
            "(No label file loaded — add tag_labels.json to the repo root.)",
        ),
        "fo": (
            "### Markingaryvirlit (hvat kóðurnar merkja)",
            "- Tips: tú kanst copy/paste úr TSV-kassanum inn í skjøl ella rokniskjøl.",
            "#### Orðaflokkar",
            "(Eingin label-fíla er innlisin — legg tag_labels.json í rótina á repo.)",
        ),
    }
    title, hint, wc_title, missing = texts[lang]

    word_classes = LABELS.get(lang, {}).get("global", {}).get("word_class", {})
    if word_classes:
        entries = [f"- **{code}**: {word_classes[code]}" for code in sorted(word_classes)]
    else:
        entries = [f"- {missing}"]

    return "\n".join([title, hint, "", wc_title] + entries)
411
 
412
+
413
  # ----------------------------
414
  # Gradio UI
415
  # ----------------------------
416
theme = gr.themes.Soft()

# NOTE(review): the nesting below is reconstructed from a flattened diff view;
# confirm that only the language dropdown belongs inside the Row.
with gr.Blocks(theme=theme, title="BRAGD-markarin") as demo:
    gr.Markdown(
        "## BRAGD-markarin\n"
        "Skriv ein setning og fá hann markaðan.\n\n"
        "**Model:** `Setur/BRAGD`"
    )

    with gr.Row():
        lang = gr.Dropdown(
            choices=[("Føroyskt", "fo"), ("English", "en")],
            value="fo",
            label="Mál / Language",
        )

    inp = gr.Textbox(lines=3, label="Setningur / Sentence", placeholder="Skriv her…")
    btn = gr.Button("Marka / Tag", variant="primary")

    out_df = gr.Dataframe(
        headers=["Word", "Tag", "Meaning"],
        wrap=True,
        interactive=False,
        label="Úrslit / Results",
    )
    out_tsv = gr.Textbox(lines=10, label="Copy/paste (TSV)", interactive=False)

    with gr.Accordion("Markingaryvirlit / Legend", open=False):
        legend_md = gr.Markdown(build_legend("fo"))

    def _run(sentence, lang_choice):
        """Tag the sentence and refresh the legend for the chosen language."""
        df, tsv = tag_sentence(sentence, lang_choice)
        return df, tsv, build_legend(lang_choice)

    btn.click(_run, inputs=[inp, lang], outputs=[out_df, out_tsv, legend_md])
    # build_legend already takes the language as its only argument, so it is
    # passed directly instead of through a wrapper lambda.
    lang.change(build_legend, inputs=[lang], outputs=[legend_md])

if __name__ == "__main__":
    demo.launch()