Spaces:
Sleeping
Sleeping
File size: 6,361 Bytes
36d956f 95062a5 36d956f 95062a5 36d956f 95062a5 36d956f 95062a5 36d956f 95062a5 36d956f 95062a5 36d956f 95062a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import streamlit as st
import pandas as pd
import plotly.graph_objects as go
from src.predict import predict_demo
from src.front import render_html
from results.output import training_log, report_dict, report_dict_2, model_compare, data_compare
# ===== Page setup =====
# NOTE: set_page_config must be the first Streamlit command and may only be
# called once per script run.
st.set_page_config(page_title="Vietnamese NER", layout="wide")

# ===== Main Title =====
st.title("🔍 Vietnamese Named Entity Recognition Demo")

# Top-level navigation: one tab per app section.
_TAB_LABELS = ["📊 Data Analysis", "📈 Training Results", "🧪 Model Demo"]
tab1, tab2, tab3 = st.tabs(_TAB_LABELS)
# --- Tab 1: DATA ANALYSIS ---
# --- Tab 1: DATA ANALYSIS ---
with tab1:
    # All EDA figures are loaded straight from the project repo's results folder.
    base = "https://raw.githubusercontent.com/duclld1709/vietnamese-ner/refs/heads/main/results/"
    col1, col2 = st.columns(2)
    # (column, file) pairs in the original top-to-bottom render order.
    figures = [
        (col1, "ner_freq.png"),       # NER label frequency
        (col2, "ner_freq_add.png"),   # NER label frequency (with crawled data added)
        (col1, "ent_dis.png"),        # entities per sentence (0 to 15+)
        (col2, "sent_len.png"),       # sentence-length distribution
        (col1, "token_len.png"),      # token-length distribution
    ]
    for target_col, fname in figures:
        with target_col:
            st.image(base + fname)
# --- Tab 2: TRAINING RESULTS ---
# --- Tab 2: TRAINING RESULTS ---

def _metric_curve(train_key, val_key, train_name, val_name, title, y_title):
    """Build a train-vs-validation lines+markers chart from ``training_log``.

    train_key/val_key index columns of ``training_log``; the display names,
    title and y-axis label are passed through to plotly.  Returns a Figure.
    """
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=training_log["epoch"], y=training_log[train_key],
                             mode='lines+markers', name=train_name))
    fig.add_trace(go.Scatter(x=training_log["epoch"], y=training_log[val_key],
                             mode='lines+markers', name=val_name))
    fig.update_layout(title=title, xaxis_title="Epoch", yaxis_title=y_title)
    return fig


def _report_df(report):
    """Flatten an sklearn-style classification-report dict into a DataFrame,
    dropping the aggregate rows (accuracy / macro avg / weighted avg)."""
    labels = [k for k in report.keys()
              if k not in ["accuracy", "macro avg", "weighted avg"]]
    rows = [[lbl,
             report[lbl]["precision"],
             report[lbl]["recall"],
             report[lbl]["f1-score"]]
            for lbl in labels]
    return pd.DataFrame(rows, columns=["Label", "Precision", "Recall", "F1-Score"])


def _report_bar(df, title):
    """Grouped bar chart of per-class Precision / Recall / F1-Score."""
    fig = go.Figure()
    for col in ["Precision", "Recall", "F1-Score"]:
        fig.add_trace(go.Bar(x=df["Label"], y=df[col], name=col))
    fig.update_layout(barmode='group',
                      title=title,
                      xaxis_title="Label", yaxis_title="Score",
                      yaxis=dict(range=[0, 1.0]))
    return fig


with tab2:
    # BUG FIX: a second st.set_page_config() call was removed from here.
    # Streamlit allows set_page_config() exactly once per script run, and it
    # must be the first Streamlit command — a second call raises
    # StreamlitAPIException.  The page is already configured at the top of
    # this file.

    # ==== CREATE FIGURES ====
    # 1️⃣ Loss and 2️⃣ F1-Score training curves
    fig_loss = _metric_curve("train_loss", "val_loss",
                             "Train Loss", "Validation Loss",
                             "Loss Curve", "Loss")
    fig_f1 = _metric_curve("train_f1", "val_f1",
                           "Train F1", "Validation F1",
                           "F1-Score Curve", "F1-Score")

    # 3️⃣ Per-class metric charts for both model variants
    df_report = _report_df(report_dict)
    fig_report = _report_bar(df_report, "Class Metrics: PhoBERT + CRF")
    df_report2 = _report_df(report_dict_2)
    fig_report2 = _report_bar(df_report2, "Class Metrics: PhoBERT + Softmax")

    # 4️⃣ Model & Data Comparison Tables
    df_model = pd.DataFrame(
        [[m, v["F1"], v["Accuracy"]] for m, v in model_compare["Data"].items()],
        columns=["Model", "F1-Score", "Accuracy"]
    )
    df_data = pd.DataFrame(
        [[s, f1] for s, f1 in data_compare["Data"].items()],
        columns=["Preprocessing", "F1-Score"]
    )

    # ==== CLEAN LAYOUT WITH COLUMNS ====
    # Row 1: Loss | F1
    col1, col2 = st.columns(2)
    with col1:
        st.plotly_chart(fig_loss, use_container_width=True)
    with col2:
        st.plotly_chart(fig_f1, use_container_width=True)

    # Row 2: Softmax bar chart (left) | CRF bar chart (right)
    col3, col4 = st.columns(2)
    with col3:
        st.plotly_chart(fig_report2, use_container_width=True)
    with col4:
        st.plotly_chart(fig_report, use_container_width=True)

    # Row 3: Model Compare | Data Compare
    col5, col6 = st.columns(2)
    with col5:
        st.markdown("**Model Comparison**")
        st.dataframe(df_model, use_container_width=True)
    with col6:
        st.markdown("**Data Preprocessing Comparison**")
        st.dataframe(df_data, use_container_width=True)
# --- Tab 3: MODEL DEMO ---
# --- Tab 3: MODEL DEMO ---
with tab3:
    text = st.text_input("Enter Vietnamese text:", "Nguyễn Văn A đang làm việc tại Hà Nội")
    if st.button("Analyze"):
        if not text.strip():
            st.warning("Please enter some text!")
        else:
            tokens, labels = predict_demo(text)
            st.subheader("Detected Entities")
            # Keep only tokens whose predicted tag is a real entity (not "O").
            found = [pair for pair in zip(tokens, labels) if pair[1] != "O"]
            if not found:
                st.info("No named entities detected.")
            else:
                for token, tag in found:
                    st.markdown(f"🔹 **{token}** — *{tag}*")
            st.subheader("Highlighted Text")
            st.markdown(render_html(tokens, labels), unsafe_allow_html=True)
|