Update app.py
Browse files
app.py
CHANGED
|
@@ -14,6 +14,9 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Auto
|
|
| 14 |
import os
|
| 15 |
import colorsys
|
| 16 |
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
|
| 19 |
hex_color = hex_color.lstrip('#')
|
|
@@ -86,26 +89,32 @@ def process_classification(text: str, model1, model2, tokenizer1) -> Tuple[str,
|
|
| 86 |
score = prediction1 / (prediction2 + prediction1)
|
| 87 |
|
| 88 |
return f"{round(prediction1, 1)}", f"{round(prediction2, 1)}", f"{round(score, 2)}"
|
| 89 |
-
|
| 90 |
-
def generate_charts(ner_output_bin: dict, ner_output_ext: dict) -> Tuple[
|
| 91 |
entities_bin = [entity['entity'] for entity in ner_output_bin['entities']]
|
| 92 |
entities_ext = [entity['entity'] for entity in ner_output_ext['entities']]
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
| 109 |
|
| 110 |
return fig1, fig2
|
| 111 |
|
|
|
|
| 14 |
import os
|
| 15 |
import colorsys
|
| 16 |
import matplotlib.pyplot as plt
|
| 17 |
+
import plotly.graph_objects as go
|
| 18 |
+
from typing import Tuple
|
| 19 |
+
import plotly.io as pio
|
| 20 |
|
| 21 |
def hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
|
| 22 |
hex_color = hex_color.lstrip('#')
|
|
|
|
| 89 |
score = prediction1 / (prediction2 + prediction1)
|
| 90 |
|
| 91 |
return f"{round(prediction1, 1)}", f"{round(prediction2, 1)}", f"{round(score, 2)}"
|
| 92 |
+
|
| 93 |
+
def generate_charts(ner_output_bin: dict, ner_output_ext: dict) -> Tuple[go.Figure, go.Figure]:
|
| 94 |
entities_bin = [entity['entity'] for entity in ner_output_bin['entities']]
|
| 95 |
entities_ext = [entity['entity'] for entity in ner_output_ext['entities']]
|
| 96 |
|
| 97 |
+
# Counting entities for binary classification
|
| 98 |
+
entity_counts_bin = {entity: entities_bin.count(entity) for entity in set(entities_bin)}
|
| 99 |
+
bin_labels = list(entity_counts_bin.keys())
|
| 100 |
+
bin_sizes = list(entity_counts_bin.values())
|
| 101 |
|
| 102 |
+
# Counting entities for extended classification
|
| 103 |
+
entity_counts_ext = {entity: entities_ext.count(entity) for entity in set(entities_ext)}
|
| 104 |
+
ext_labels = list(entity_counts_ext.keys())
|
| 105 |
+
ext_sizes = list(entity_counts_ext.values())
|
| 106 |
+
|
| 107 |
+
# Create pie chart for extended classification
|
| 108 |
+
fig1 = go.Figure(data=[go.Pie(labels=ext_labels, values=ext_sizes, textinfo='label+percent', hole=.3)])
|
| 109 |
+
fig1.update_layout(title_text='Extended Sequence Classification Subclasses')
|
| 110 |
|
| 111 |
+
# Create bar chart for binary classification
|
| 112 |
+
fig2 = go.Figure(data=[go.Bar(x=bin_labels, y=bin_sizes)])
|
| 113 |
+
fig2.update_layout(
|
| 114 |
+
title='Binary Sequence Classification Classes',
|
| 115 |
+
xaxis_title='Entity Type',
|
| 116 |
+
yaxis_title='Count'
|
| 117 |
+
)
|
| 118 |
|
| 119 |
return fig1, fig2
|
| 120 |
|