# demo/src/pages/Test_Evaluation.py
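"""Streamlit page for evaluating pipeline test runs stored in MongoDB.

Visualizes page-type classification metrics (sunburst, confusion matrix,
score charts), event extraction metrics, per-record detail tables, and a
timeline of all completed test runs.
"""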
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import streamlit as st
from bson import ObjectId
from html_to_markdown import convert
from services import init_connection
st.set_page_config(layout="wide")
st.markdown(
"""
<style>
.block-container {
width: 80vw;
max-width: 1400px;
margin: 0 auto;
}
</style>
""",
unsafe_allow_html=True,
)
def print_schedule_obj(s):
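    """Format a schedule dict as 'dd.mm.YYYY - dd.mm.YYYY | HH:MM - HH:MM'.

    Accepts both snake_case keys ("start_date") and legacy concatenated keys ("startdate").
    """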
    start_date = s.get("start_date") or s.get("startdate")
    end_date = s.get("end_date") or s.get("enddate")
    start_time = s.get("start_time") or s.get("starttime")
    end_time = s.get("end_time") or s.get("endtime")
    start_date_str = start_date.strftime("%d.%m.%Y") if start_date else ""
    end_date_str = end_date.strftime("%d.%m.%Y") if end_date else ""
    start_time_str = start_time.strftime("%H:%M") if start_time else ""
    end_time_str = end_time.strftime("%H:%M") if end_time else ""
return f"{start_date_str} - {end_date_str} | {start_time_str} - {end_time_str}\n\n"
def create_data_metrics_df(overall_metrics: dict) -> pd.DataFrame:
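    """Build a per-field DataFrame of event metrics: one row per field, one column per float-valued metric."""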
field_order = ["page_type", "title", "locations", "schedule", "start_date", "end_date", "start_time", "end_time"]
    # Seed with the configured order; fields outside field_order are appended as encountered.
    rows = {key: "" for key in field_order if key in overall_metrics}
    for field, metrics in overall_metrics.items():
        # Keep only float-valued metrics so the frame can be charted directly.
        rows[field] = {k: v for k, v in metrics.items() if isinstance(v, float)}
df = pd.DataFrame(rows).T
return df
def create_confusion_matrix(overall_metrics: dict):
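    """Render the EVENT / NO_EVENT confusion matrix as a seaborn heatmap."""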
page_type_metric = overall_metrics.get("page_type", {})
tp_count = page_type_metric.get("tp", 0)
tn_count = page_type_metric.get("tn", 0)
fp_count = page_type_metric.get("fp", 0)
fn_count = page_type_metric.get("fn", 0)
cm = pd.DataFrame(
[[tn_count, fp_count],
[fn_count, tp_count]],
index=['Expected NO_EVENT', 'Expected EVENT'],
columns=['Predicted NO_EVENT', 'Predicted EVENT']
)
fig, ax = plt.subplots()
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Expected')
return fig
def create_fn_df(record_results: dict):
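    """Tabulate the rejection reasons behind all false-negative records, with counts and percentages."""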
    false_negatives = [
        v.get("data", {}).get("reason")
        for v in record_results.values()
        if v.get("record_metrics", {}).get("page_type", {}).get("fn") == 1
    ]
fn_counts = pd.Series(false_negatives).value_counts()
fn_percent = fn_counts / len(record_results) * 100
df = pd.DataFrame({
'Rejected Reason': fn_counts.index,
'Anzahl': fn_counts.values,
'Prozent': fn_percent.values
})
total_row = pd.DataFrame([{
'Rejected Reason': 'Gesamt',
'Anzahl': fn_counts.sum(),
'Prozent': fn_percent.sum()
}])
df = pd.concat([df, total_row], ignore_index=True)
df = df.style.format({'Prozent': '{:.1f}%'})
return df
def create_error_df(overall_metrics: dict, batchsize: int):
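    """Tabulate each error type with its count and its share of the batch."""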
rows = []
for k, v in overall_metrics.get("error", {}).items():
rows.append({
'Error': k.upper(),
'Anzahl': v,
'Prozent': v / batchsize * 100
})
df = pd.DataFrame(rows)
return df
def create_sunburst_chart(overall_metrics: dict, batchsize: int):
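    """Build a sunburst chart: total -> correct / incorrect / error -> TP, TN, FP, FN and error types.

    Records that fall into none of the four confusion-matrix cells are counted as errors.
    """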
page_type_metrics = overall_metrics.get("page_type", {})
tp = page_type_metrics.get("tp", 0)
fn = page_type_metrics.get("fn", 0)
fp = page_type_metrics.get("fp", 0)
tn = page_type_metrics.get("tn", 0)
error = batchsize - tp - fn - fp - tn
correct = tp + tn
incorrect = fp + fn
error_df = create_error_df(overall_metrics, batchsize)
labels = ["Gesamt", "Korrekt", "Falsch", "Error", "True Positive", "True Negative", "False Positive",
"False Negative"]
parents = ["", "Gesamt", "Gesamt", "Gesamt", "Korrekt", "Korrekt", "Falsch", "Falsch"]
values = [batchsize, correct, incorrect, error, tp, tn, fp, fn]
for i, row in error_df.iterrows():
labels.append(row['Error'])
parents.append("Error")
values.append(row['Anzahl'])
colors = ["#FFFFFF", "#7FD1B9", "#FFB284", "#FF8585", "#5BC0BE", "#379683", "#F2881A", "#F7B32B"]
colors.extend(["#FF8585"] * len(error_df))
fig = go.Figure(go.Sunburst(
labels=labels,
parents=parents,
values=values,
branchvalues="total",
marker=dict(colors=colors),
hovertemplate='<b>%{label}</b><br>Anzahl: %{value}<br>Prozent: %{percentParent:.1%}<extra></extra>'
))
fig.update_layout(margin=dict(t=0, b=0, l=0, r=0))
return fig
def create_page_type_scores_df(overall_metrics: dict) -> pd.DataFrame:
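    """Collect the page-type classification scores into a single-column DataFrame for st.bar_chart."""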
page_type_metrics = overall_metrics.get("page_type", {})
df = pd.DataFrame([{
"precision": page_type_metrics.get("precision", 0),
"recall": page_type_metrics.get("recall", 0),
"f1": page_type_metrics.get("f1", 0),
"accuracy": page_type_metrics.get("accuracy", 0),
"effective_accuracy": page_type_metrics.get("effective_accuracy", 0)
}]).T
return df
def create_detail_table(test: dict):
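    """Build a long-format table comparing expected and predicted values per record.

    Expected values are looked up in db.testdata_1 via the module-level connection.
    """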
    def _stringify(v):
        if isinstance(v, list):
            return ", ".join(map(str, v))
        return str(v)
rows = []
meta_columns = set()
for record_id, result in test.get("record_results", {}).items():
validation = db.testdata_1.find_one({"_id": ObjectId(record_id)})
expected = validation.get("data", {})
predicted = result.get("data", {})
metrics = result.get("record_metrics", {})
meta = result.get("meta", {})
        # collect meta columns (the set of meta keys can differ per record)
meta_columns |= {f"Meta - {k}" for k in meta}
        # header row for each record
head = {
"Record ID": str(record_id),
"Field": "",
"Expected": "",
"Predicted": "",
"Metrics": "",
**{f"Meta - {k}": _stringify(meta.get(k, "")) for k in meta}
}
rows.append(head)
def add(field, exp, pred):
val = metrics.get(field)
if isinstance(val, dict):
val = " | ".join(f"{k}: {v:.2f}" for k, v in val.items())
rows.append({
"Record ID": "",
"Field": field,
"Expected": exp,
"Predicted": pred,
"Metrics": val if val is not None else "",
**{col: "" for col in meta_columns}
})
add("page_type", validation.get("page_type"), result.get("page_type"))
add("title", expected.get("title"), predicted.get("title"))
add(
"schedule",
"\n\n".join(print_schedule_obj(s) for s in expected.get("schedule", [])),
"\n\n".join(print_schedule_obj(s) for s in predicted.get("schedule", []))
)
add(
"locations",
", ".join(g.get("geolocation", {}).get("formatted", "") for g in expected.get("locations", [])),
", ".join(g.get("geolocation", {}).get("formatted", "") for g in predicted.get("locations", []))
)
rows.append({col: "" for col in ["Record ID", "Field", "Expected", "Predicted", "Metrics", *meta_columns]})
return pd.DataFrame(rows)
def create_event_score_chart(test: dict):
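    """Plot a histogram of per-record event scores, annotated with the mean score of the test run."""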
    event_scores = [
        r["record_metrics"].get("event_score")
        for r in test["record_results"].values()
        if r["record_metrics"].get("event_score") is not None
    ]
mean_score = test.get("overall_metrics", {}).get("event_score", 0)
fig = go.Figure()
fig.add_trace(go.Histogram(
x=event_scores,
        name='event_score',
xbins=dict(
start=0.0,
end=1.1,
size=0.1
),
marker=dict(
color="#43cd80",
            line=dict(color='white', width=1)  # white borders separate the bars
),
))
fig.update_layout(
xaxis=dict(tickvals=[i / 10 for i in range(11)]),
yaxis_title="Anzahl Events",
xaxis_title="Event Score",
title="Event Score",
annotations=[
dict(
x=0.02,
y=0.94,
xref="paper",
yref="paper",
text=f"Ø Event Score: {mean_score:.2f}",
showarrow=False,
align="left",
font=dict(size=13),
bgcolor="rgba(255,255,255,0.8)",
bordercolor="#ccc",
borderwidth=1
)
]
)
return fig
@st.dialog("Original Seite", width="medium")
def show_website(url, html):
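    """Render the crawled page as Markdown in a dialog, with a link to the original URL."""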
st.info(f"Link zur Original Website: {url}")
md = convert(html)
st.write(md)
st.title("Test Evaluation")
db = init_connection()
tests = list(db.test_evaluation.find({}, {"_id": 1, "status": 1, "created_at": 1}))
tests_sorted = sorted(tests, key=lambda t: t["created_at"], reverse=True)
options = {str(t["_id"]): f"{t['status']} - {t['created_at'].strftime('%Y-%m-%d %H:%M:%S')}" for t in tests_sorted}
selected_id = st.selectbox("Wähle einen Test aus", options=list(options.keys()), format_func=lambda x: options[x])
if selected_id:
test = db.test_evaluation.find_one({"_id": ObjectId(selected_id)})
record_results = test.get("record_results", {})
batchsize = len(record_results)
st.success(
f"**Test ID:** {selected_id} | "
f"**Status:** {test.get('status')} | "
f"**Batchsize:** {batchsize}"
)
overall_metrics = test.get("overall_metrics", {})
if overall_metrics:
df_data_metrics = create_data_metrics_df(overall_metrics.get("event_metrics", {}))
cm_fig = create_confusion_matrix(overall_metrics)
df_fn = create_fn_df(record_results)
df_error = create_error_df(overall_metrics, batchsize)
fig_event_score = create_event_score_chart(test)
        page_type_sunburst_chart = create_sunburst_chart(overall_metrics, batchsize)
        page_type_scores_df = create_page_type_scores_df(overall_metrics)
st.write("## Page Type Metriken")
st.write(
"Klassifikation einer Website als Event- oder Nicht-Event-Seite während der Pipeline.")
col1, col2 = st.columns([2, 1.5])
with col1:
            st.plotly_chart(page_type_sunburst_chart)
with col2:
st.write("")
st.write("")
st.markdown("""
<span style="font-size:12px">
<span style="color:#5BC0BE">■</span> <b>True Positive (TP):</b> Event-Seite korrekt erkannt<br>
<span style="color:#379683">■</span> <b>True Negative (TN):</b> Nicht-Event korrekt erkannt<br>
<span style="color:#F2881A">■</span> <b>False Positive (FP):</b> Nicht-Event fälschlich als Event erkannt<br>
<span style="color:#F7B32B">■</span> <b>False Negative (FN):</b> Event-Seite nicht erkannt<br>
<span style="color:#FF8585">■</span> <b>Error:</b> Fehler während Verarbeitung<br>
&nbsp;&nbsp;<span style="color:#FF8585">●</span> RATE_LIMIT_ERROR: LLM API-Limit erreicht<br>
&nbsp;&nbsp;<span style="color:#FF8585">●</span> INVALID_EVENT: Event extrahiert, relevante Daten fehlten<br>
&nbsp;&nbsp;<span style="color:#FF8585">●</span> INVALID_FORMAT: Event extrahiert, aber falsches JSON<br>
&nbsp;&nbsp;<span style="color:#FF8585">●</span> ERROR: Andere Fehlerarten
</span>
""", unsafe_allow_html=True)
col1, col2 = st.columns(2)
with col1:
st.write("#### Confusion Matrix")
st.pyplot(cm_fig, width=450)
with col2:
st.write("#### Gründe für False Negatives")
st.dataframe(df_fn)
        col1, col2 = st.columns([2, 1])
with col1:
st.write("#### Scores")
            st.bar_chart(page_type_scores_df, height=450)
with col2:
st.space(size=100)
st.markdown("""
<span style="font-size:12px">
<span style="color:#5BC0BE">■</span> <b>Accuracy:</b> Anteil korrekt klassifizierter Seiten an allen klassifizierten Seiten (ohne Errors)<br>
<span style="color:#379683">■</span> <b>Effective Accuracy:</b> Anteil korrekt klassifizierter Seiten bezogen auf alle Testergebnisse (mit Errors)<br>
<span style="color:#F2881A">■</span> <b>F1:</b> Harmonic Mean aus Precision und Recall<br>
<span style="color:#F7B32B">■</span> <b>Precision:</b> Anteil der als Event erkannten Seiten, die tatsächlich Events sind<br>
<span style="color:#FF8585">■</span> <b>Recall:</b> Anteil der tatsächlichen Event-Seiten, die korrekt erkannt wurden<br>
</span>
""", unsafe_allow_html=True)
st.write("---")
st.write("## Event-Metriken")
st.write("Qualität und Korrektheit der extrahierten Event-Informationen.")
col1, col2 = st.columns([1, 2])
with col1:
st.space(size=100)
st.markdown("""
<span style="font-size:12px">
<span style="color:#43cd80">■</span> <b>Event Score:</b> Gesamtbewertung der Event-Qualität, berechnet aus F1-Score und Match Scores der einzelnen Felder<br>
<span style="color:#ff2b2b">■</span> <b>Precision:</b> Anteil korrekt extrahierter Informationen<br>
<span style="color:#ffabab">■</span> <b>Recall:</b> Anteil erkannter Informationen von allen erwarteten<br>
<span style="color:#0068c9">■</span> <b>F1-Score:</b> Harmonisches Mittel aus Precision und Recall<br>
<span style="color:#83c9ff">■</span> <b>Match Score:</b> Textähnlichkeit zweier Strings (Fuzzy Matching)
</span>
""", unsafe_allow_html=True, width=300)
with col2:
st.plotly_chart(fig_event_score)
st.write("**Ergebnisse der einzelnen Event-Informationen**")
st.bar_chart(df_data_metrics, stack=False, sort=False)
else:
st.info("Der Test läuft noch. Es konnte noch keine Metric erstellt werden")
with st.expander("Testergebnisse im Detail"):
df = create_detail_table(test)
st.dataframe(df, height=600)
    record_id = st.text_input(label="Gib eine Record ID ein, um die Original-Website anzusehen.", value="")
    if record_id:
        record = db.testdata_1.find_one({"_id": ObjectId(record_id)})
        if record:
            html = record.get("html")
            url = record.get("url")
            if html:
                # stored HTML may be raw bytes; decode before rendering
                if isinstance(html, bytes):
                    html = html.decode("utf-8")
                show_website(url, html)
with st.expander("Ergebnisse aller Tests im Verlauf"):
pipeline = [
{"$match": {"status": "completed"}},
{"$project": {
"_id": 1,
"created_at": 1,
"overall_metrics": 1,
"pipeline_version": 1,
"batchsize": {
"$size": {
"$objectToArray": {
"$ifNull": ["$record_results", {}]
}
}
}
}}
]
tests = list(db.test_evaluation.aggregate(pipeline))
if not tests:
st.info("Es sind noch keine Testergebnisse vorhanden.")
else:
event_scores_time_series = pd.DataFrame([
{
"timestamp": pd.to_datetime(t.get("created_at")),
"pipeline_version": t.get("pipeline_version"),
"event_score": t.get("overall_metrics", {}).get("event_score"),
"errors": sum(t.get("overall_metrics", {}).get("error", {"error": 90}).values()) / t.get(
"batchsize") * 100,
"page_type_effective_accuracy": t.get("overall_metrics", {}).get("page_type", {}).get(
"effective_accuracy", 0) * 100,
"page_type_precision": t.get("overall_metrics", {}).get("page_type", {}).get("precision", 0) * 100,
"page_type_recall": t.get("overall_metrics", {}).get("page_type", {}).get("recall", 0) * 100,
"page_type_f1": t.get("overall_metrics", {}).get("page_type", {}).get("f1", 0) * 100,
"page_type_accuracy": t.get("overall_metrics", {}).get("page_type", {}).get("accuracy", 0) * 100,
}
for t in tests
])
        df = event_scores_time_series.sort_values("timestamp")
fig = px.line(
df,
x="timestamp",
y=["event_score", "errors", "page_type_effective_accuracy", "page_type_precision", "page_type_recall",
"page_type_f1",
"page_type_accuracy"],
hover_data=["pipeline_version"],
labels={
"value": "Prozent",
"variable": "Metrik"
},
markers=True
)
fig.update_yaxes(tick0=0, dtick=10, title="Wert in Prozent")
st.plotly_chart(fig, use_container_width=True)