# GitHub Actions
# Auto-deploy from GitHub (binary files removed)
# 95062a5
import streamlit as st
import pandas as pd
import plotly.graph_objects as go
from src.predict import predict_demo
from src.front import render_html
from results.output import training_log, report_dict, report_dict_2, model_compare, data_compare
# ===== Page setup =====
# Wide layout so the side-by-side charts in the tabs have room to render.
st.set_page_config(layout="wide", page_title="Vietnamese NER")

# ===== Main Title =====
st.title("🔍 Vietnamese Named Entity Recognition Demo")

# ===== Tabs =====
# One tab per section of the app; the labels are shown in this order.
TAB_LABELS = ["📊 Data Analysis", "📈 Training Results", "🧪 Model Demo"]
tab1, tab2, tab3 = st.tabs(TAB_LABELS)
# --- Tab 1: DATA ANALYSIS ---
with tab1:
    col1, col2 = st.columns(2)

    # (target column, image URL) pairs, listed in the order they are rendered.
    # Images alternate between the left and right column.
    analysis_images = (
        # Distribution of NER label frequency
        (col1, "https://raw.githubusercontent.com/duclld1709/vietnamese-ner/refs/heads/main/results/ner_freq.png"),
        # Distribution of NER label frequency (with crawled data added)
        (col2, "https://raw.githubusercontent.com/duclld1709/vietnamese-ner/refs/heads/main/results/ner_freq_add.png"),
        # Distribution of the number of entities per sentence (0 to 15+)
        (col1, "https://raw.githubusercontent.com/duclld1709/vietnamese-ner/refs/heads/main/results/ent_dis.png"),
        # Distribution of sentence lengths
        (col2, "https://raw.githubusercontent.com/duclld1709/vietnamese-ner/refs/heads/main/results/sent_len.png"),
        # Distribution of token lengths
        (col1, "https://raw.githubusercontent.com/duclld1709/vietnamese-ner/refs/heads/main/results/token_len.png"),
    )

    for column, image_url in analysis_images:
        with column:
            st.image(image_url)
# --- Tab 2: TRAINING RESULTS ---

# Rows of a classification report that summarise all classes rather than
# describing a single label; they are excluded from the per-class charts.
# "micro avg" is included defensively: some report generators emit it in
# place of (or alongside) "accuracy".
_AGGREGATE_ROWS = ("accuracy", "micro avg", "macro avg", "weighted avg")


def _metric_curve(log, series, title, y_title):
    """Build a line chart of one metric over epochs.

    Args:
        log: Mapping that provides an "epoch" sequence plus one sequence per
            column named in ``series``.
        series: Iterable of ``(log_column, trace_name)`` pairs, one per line.
        title: Chart title.
        y_title: Label for the y axis.

    Returns:
        A ``plotly.graph_objects.Figure`` with one lines+markers trace per pair.
    """
    fig = go.Figure()
    for column, trace_name in series:
        fig.add_trace(go.Scatter(x=log["epoch"], y=log[column],
                                 mode="lines+markers", name=trace_name))
    fig.update_layout(title=title, xaxis_title="Epoch", yaxis_title=y_title)
    return fig


def _class_metrics_frame(report):
    """Turn a classification-report dict into a per-label DataFrame.

    Args:
        report: Mapping of label -> {"precision", "recall", "f1-score", ...};
            aggregate rows listed in ``_AGGREGATE_ROWS`` are skipped.

    Returns:
        DataFrame with columns Label, Precision, Recall, F1-Score.
    """
    rows = [
        [label, scores["precision"], scores["recall"], scores["f1-score"]]
        for label, scores in report.items()
        if label not in _AGGREGATE_ROWS
    ]
    return pd.DataFrame(rows, columns=["Label", "Precision", "Recall", "F1-Score"])


def _class_metrics_figure(frame, title):
    """Build a grouped bar chart (precision / recall / F1 per label).

    Args:
        frame: DataFrame as produced by ``_class_metrics_frame``.
        title: Chart title.

    Returns:
        A ``plotly.graph_objects.Figure`` with one bar trace per metric.
    """
    fig = go.Figure()
    for metric in ("Precision", "Recall", "F1-Score"):
        fig.add_trace(go.Bar(x=frame["Label"], y=frame[metric], name=metric))
    fig.update_layout(barmode="group",
                      title=title,
                      xaxis_title="Label", yaxis_title="Score",
                      yaxis=dict(range=[0, 1.0]))  # scores are fractions in [0, 1]
    return fig


with tab2:
    # NOTE: the page is configured once at the top of the script. A second
    # st.set_page_config(...) call used to live here; Streamlit has
    # historically required that call to be the first Streamlit command and
    # to run only once, raising StreamlitAPIException otherwise, so it was
    # removed. Its only extra option was initial_sidebar_state="expanded";
    # add that to the top-of-script call if a sidebar is introduced.

    # ==== CREATE FIGURES ====
    # 1️⃣ Loss
    fig_loss = _metric_curve(
        training_log,
        [("train_loss", "Train Loss"), ("val_loss", "Validation Loss")],
        title="Loss Curve",
        y_title="Loss",
    )

    # 2️⃣ F1-Score
    fig_f1 = _metric_curve(
        training_log,
        [("train_f1", "Train F1"), ("val_f1", "Validation F1")],
        title="F1-Score Curve",
        y_title="F1-Score",
    )

    # 3️⃣ Per-class metrics for both model variants
    df_report = _class_metrics_frame(report_dict)
    fig_report = _class_metrics_figure(df_report, "Class Metrics: PhoBERT + CRF")

    df_report2 = _class_metrics_frame(report_dict_2)
    fig_report2 = _class_metrics_figure(df_report2, "Class Metrics: PhoBERT + Softmax")

    # 4️⃣ Model & Data Comparison Tables
    df_model = pd.DataFrame(
        [[model, scores["F1"], scores["Accuracy"]]
         for model, scores in model_compare["Data"].items()],
        columns=["Model", "F1-Score", "Accuracy"],
    )
    df_data = pd.DataFrame(
        list(data_compare["Data"].items()),
        columns=["Preprocessing", "F1-Score"],
    )

    # ==== CLEAN LAYOUT WITH COLUMNS ====
    # Row 1: Loss | F1
    col1, col2 = st.columns(2)
    with col1:
        st.plotly_chart(fig_loss, use_container_width=True)
    with col2:
        st.plotly_chart(fig_f1, use_container_width=True)

    # Row 2: Class metrics (PhoBERT + Softmax) | Class metrics (PhoBERT + CRF)
    col3, col4 = st.columns(2)
    with col3:
        st.plotly_chart(fig_report2, use_container_width=True)
    with col4:
        st.plotly_chart(fig_report, use_container_width=True)

    # Row 3: Model Compare | Data Compare
    col5, col6 = st.columns(2)
    with col5:
        st.markdown("**Model Comparison**")
        st.dataframe(df_model, use_container_width=True)
    with col6:
        st.markdown("**Data Preprocessing Comparison**")
        st.dataframe(df_data, use_container_width=True)
# --- Tab 3: MODEL DEMO ---
with tab3:
    # Free-text input, pre-filled with a sample sentence.
    text = st.text_input("Enter Vietnamese text:", "Nguyễn Văn A đang làm việc tại Hà Nội")
    analyze_clicked = st.button("Analyze")

    if analyze_clicked and not text.strip():
        # Whitespace-only input: nothing to send to the model.
        st.warning("Please enter some text!")
    elif analyze_clicked:
        tokens, labels = predict_demo(text)

        # List every token whose predicted label is not the "O" tag.
        st.subheader("Detected Entities")
        entities = [pair for pair in zip(tokens, labels) if pair[1] != "O"]
        if not entities:
            st.info("No named entities detected.")
        else:
            for tok, lab in entities:
                st.markdown(f"🔹 **{tok}** — *{lab}*")

        # Full sentence rendered as HTML produced by render_html.
        # NOTE(review): the markup is injected unescaped and the tokens come
        # from user input — confirm render_html escapes token text.
        st.subheader("Highlighted Text")
        highlighted_html = render_html(tokens, labels)
        st.markdown(highlighted_html, unsafe_allow_html=True)