Nathan Schneider commited on
Commit
b3e016c
·
1 Parent(s): a835f40

debug HighlightedText output (#17) using a canned example

Browse files
Files changed (1) hide show
  1. app.py +34 -19
app.py CHANGED
@@ -200,7 +200,7 @@ class MyPipeline(TokenClassificationPipeline):
200
  return results_with_probs
201
 
202
  @spaces.GPU
203
- def classify_tokens(text: str):
204
  """Main function for SNACS text classification that is called in the huggingface space
205
  Input: string to be tagged
206
  Output: HTML styled rendering of tagged outputs
@@ -230,24 +230,37 @@ def classify_tokens(text: str):
230
  "#9edae5"
231
  ][::-1] # reverse-sort to put the lighter colors first
232
 
233
- model_name = "WesScivetti/SNACS_Multilingual"
234
-
235
- tokenizer = AutoTokenizer.from_pretrained(model_name)
236
- model = AutoModelForTokenClassification.from_pretrained(model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else None)
237
- # ONE pipeline; override aggregation per-call
238
- pipe = MyPipeline(
239
- model=model,
240
- tokenizer=tokenizer,
241
- device=0,
242
- framework="pt"
243
- )
244
-
245
- # tagged spans
246
- results_spans = pipe(text, aggregation_strategy="simple").sort(key=lambda x: x["start"])
247
 
248
- # per-token + probabilities
249
- results_tokens = pipe(text, aggregation_strategy="none", ignore_labels=[]).sort(key=lambda x: x["start"])
250
- print(results_tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
  # color helper that tolerates B-/I- prefixes
253
  def pick_color(label: str, lbl2color: dict) -> str:
@@ -336,7 +349,9 @@ def classify_tokens(text: str):
336
 
337
  styled_html1 = f"<div style='font-family:sans-serif;line-height:1.6;'>{output1}</div>"
338
  styled_html2 = f"<div style='font-family:sans-serif;line-height:1.6;'>{output2}</div>"
339
- return results_spans, json.dumps(results_spans), json.dumps(results_tokens), styled_html1, table_html, styled_html2
 
 
340
  # except Exception as e:
341
  # # Force the real error into the Space logs
342
  # import traceback, sys
 
200
  return results_with_probs
201
 
202
  @spaces.GPU
203
+ def classify_tokens(text: str, use_canned=False):
204
  """Main function for SNACS text classification that is called in the huggingface space
205
  Input: string to be tagged
206
  Output: HTML styled rendering of tagged outputs
 
230
  "#9edae5"
231
  ][::-1] # reverse-sort to put the lighter colors first
232
 
233
+ if not use_canned:
234
+ model_name = "WesScivetti/SNACS_Multilingual"
235
+
236
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
237
+ model = AutoModelForTokenClassification.from_pretrained(model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else None)
238
+ # ONE pipeline; override aggregation per-call
239
+ pipe = MyPipeline(
240
+ model=model,
241
+ tokenizer=tokenizer,
242
+ device=0,
243
+ framework="pt"
244
+ )
 
 
245
 
246
+ # tagged spans
247
+ results_spans = pipe(text, aggregation_strategy="simple").sort(key=lambda x: x["start"])
248
+
249
+ # per-token + probabilities
250
+ results_tokens = pipe(text, aggregation_strategy="none", ignore_labels=[]).sort(key=lambda x: x["start"])
251
+ print(results_tokens)
252
+ else: # canned example to test the output display
253
+ text = "fox in socks"
254
+ results_spans = [{"start": 4, "end": 6, "entity_group": "p.Locus-p.Locus",
255
+ "score": 0.46, "word": "in"}]
256
+ results_tokens = [
257
+ {"start": 0, "end": 3, "entity": "O", "score": 1,
258
+ "probabilities": {"O": 1}},
259
+ {"start": 4, "end": 6, "entity": "B-p.Locus-p.Locus", "score": 0.46,
260
+ "probabilities": {"B-p.Locus-p.Locus": 0.46, "B-p.Circumstance-p.Circumstance": 0.3, "B-p.Circumstance-p.Locus": 0.2}},
261
+ {"start": 7, "end": 12, "entity": "O", "score": 1,
262
+ "probabilities": {"O": 1}}
263
+ ]
264
 
265
  # color helper that tolerates B-/I- prefixes
266
  def pick_color(label: str, lbl2color: dict) -> str:
 
349
 
350
  styled_html1 = f"<div style='font-family:sans-serif;line-height:1.6;'>{output1}</div>"
351
  styled_html2 = f"<div style='font-family:sans-serif;line-height:1.6;'>{output2}</div>"
352
+
353
+ simple_output_data = {"text": text, "entities": [{**e} | {"entity_group": display_label(e["entity_group"])} for e in results_spans]}
354
+ return simple_output_data, json.dumps(results_spans), json.dumps(results_tokens), styled_html1, table_html, styled_html2
355
  # except Exception as e:
356
  # # Force the real error into the Space logs
357
  # import traceback, sys