Spaces:
Sleeping
Sleeping
Nathan Schneider
committed on
Commit
·
b3e016c
1
Parent(s):
a835f40
debug HighlightedText output (#17) using a canned example
Browse files
app.py
CHANGED
|
@@ -200,7 +200,7 @@ class MyPipeline(TokenClassificationPipeline):
|
|
| 200 |
return results_with_probs
|
| 201 |
|
| 202 |
@spaces.GPU
|
| 203 |
-
def classify_tokens(text: str):
|
| 204 |
"""Main function for SNACS text classification that is called in the huggingface space
|
| 205 |
Input: string to be tagged
|
| 206 |
Output: HTML styled rendering of tagged outputs
|
|
@@ -230,24 +230,37 @@ def classify_tokens(text: str):
|
|
| 230 |
"#9edae5"
|
| 231 |
][::-1] # reverse-sort to put the lighter colors first
|
| 232 |
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
# tagged spans
|
| 246 |
-
results_spans = pipe(text, aggregation_strategy="simple").sort(key=lambda x: x["start"])
|
| 247 |
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
# color helper that tolerates B-/I- prefixes
|
| 253 |
def pick_color(label: str, lbl2color: dict) -> str:
|
|
@@ -336,7 +349,9 @@ def classify_tokens(text: str):
|
|
| 336 |
|
| 337 |
styled_html1 = f"<div style='font-family:sans-serif;line-height:1.6;'>{output1}</div>"
|
| 338 |
styled_html2 = f"<div style='font-family:sans-serif;line-height:1.6;'>{output2}</div>"
|
| 339 |
-
|
|
|
|
|
|
|
| 340 |
# except Exception as e:
|
| 341 |
# # Force the real error into the Space logs
|
| 342 |
# import traceback, sys
|
|
|
|
| 200 |
return results_with_probs
|
| 201 |
|
| 202 |
@spaces.GPU
|
| 203 |
+
def classify_tokens(text: str, use_canned=False):
|
| 204 |
"""Main function for SNACS text classification that is called in the huggingface space
|
| 205 |
Input: string to be tagged
|
| 206 |
Output: HTML styled rendering of tagged outputs
|
|
|
|
| 230 |
"#9edae5"
|
| 231 |
][::-1] # reverse-sort to put the lighter colors first
|
| 232 |
|
| 233 |
+
if not use_canned:
|
| 234 |
+
model_name = "WesScivetti/SNACS_Multilingual"
|
| 235 |
+
|
| 236 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 237 |
+
model = AutoModelForTokenClassification.from_pretrained(model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else None)
|
| 238 |
+
# ONE pipeline; override aggregation per-call
|
| 239 |
+
pipe = MyPipeline(
|
| 240 |
+
model=model,
|
| 241 |
+
tokenizer=tokenizer,
|
| 242 |
+
device=0,
|
| 243 |
+
framework="pt"
|
| 244 |
+
)
|
|
|
|
|
|
|
| 245 |
|
| 246 |
+
# tagged spans
|
| 247 |
+
results_spans = pipe(text, aggregation_strategy="simple").sort(key=lambda x: x["start"])
|
| 248 |
+
|
| 249 |
+
# per-token + probabilities
|
| 250 |
+
results_tokens = pipe(text, aggregation_strategy="none", ignore_labels=[]).sort(key=lambda x: x["start"])
|
| 251 |
+
print(results_tokens)
|
| 252 |
+
else: # canned example to test the output display
|
| 253 |
+
text = "fox in socks"
|
| 254 |
+
results_spans = [{"start": 4, "end": 6, "entity_group": "p.Locus-p.Locus",
|
| 255 |
+
"score": 0.46, "word": "in"}]
|
| 256 |
+
results_tokens = [
|
| 257 |
+
{"start": 0, "end": 3, "entity": "O", "score": 1,
|
| 258 |
+
"probabilities": {"O": 1}},
|
| 259 |
+
{"start": 4, "end": 6, "entity": "B-p.Locus-p.Locus", "score": 0.46,
|
| 260 |
+
"probabilities": {"B-p.Locus-p.Locus": 0.46, "B-p.Circumstance-p.Circumstance": 0.3, "B-p.Circumstance-p.Locus": 0.2}},
|
| 261 |
+
{"start": 7, "end": 12, "entity": "O", "score": 1,
|
| 262 |
+
"probabilities": {"O": 1}}
|
| 263 |
+
]
|
| 264 |
|
| 265 |
# color helper that tolerates B-/I- prefixes
|
| 266 |
def pick_color(label: str, lbl2color: dict) -> str:
|
|
|
|
| 349 |
|
| 350 |
styled_html1 = f"<div style='font-family:sans-serif;line-height:1.6;'>{output1}</div>"
|
| 351 |
styled_html2 = f"<div style='font-family:sans-serif;line-height:1.6;'>{output2}</div>"
|
| 352 |
+
|
| 353 |
+
simple_output_data = {"text": text, "entities": [{**e} | {"entity_group": display_label(e["entity_group"])} for e in results_spans]}
|
| 354 |
+
return simple_output_data, json.dumps(results_spans), json.dumps(results_tokens), styled_html1, table_html, styled_html2
|
| 355 |
# except Exception as e:
|
| 356 |
# # Force the real error into the Space logs
|
| 357 |
# import traceback, sys
|