# open-model-evolution / graphs/model_characteristics.py
# Last commit b466419 ("refactor code") by emsesc; 3.8 kB.
import plotly.graph_objects as go
import plotly.express as px
def create_concentration_chart(
    df,
    period_col,
    metric_col,
    value_col,
    order,
    palette,
    legend_title="Language Concentration",
):
    """Build a stacked area chart of ``value_col`` split by category over time.

    One filled area trace is added per entry of ``order``. The traces share a
    single stack group, so each band sits on top of the previous one in
    ``order``. Rows whose ``metric_col`` value is not listed in ``order`` are
    not plotted.

    Args:
        df: Table of observations; presumably a pandas DataFrame (it is
            filtered with boolean indexing and sorted with ``sort_values``).
        period_col: Column holding the x-axis (time) values; each category's
            rows are sorted by it before plotting.
        metric_col: Column holding the category each row belongs to.
        value_col: Column holding the y values that are stacked.
        order: Categories to plot, bottom of the stack first.
        palette: Sequence of fill colours, cycled when shorter than ``order``.
            An empty palette raises ``ZeroDivisionError`` if ``order`` is
            non-empty.
        legend_title: Title displayed above the legend. Defaults to the value
            that used to be hard-coded, "Language Concentration".

    Returns:
        A ``plotly.graph_objects.Figure`` with the legend placed to the right
        of the plot area.
    """
    fig = go.Figure()

    for i, metric in enumerate(order):
        # Sort by time so each band's outline is drawn left to right.
        metric_data = df[df[metric_col] == metric].sort_values(period_col)

        fig.add_trace(
            go.Scatter(
                x=metric_data[period_col],
                y=metric_data[value_col],
                name=metric,
                mode='lines',
                # Hide the outline; only the filled band is visible.
                line=dict(width=0),
                # The first band fills down to zero; later bands fill to the one below.
                fill='tonexty' if i > 0 else 'tozeroy',
                fillcolor=palette[i % len(palette)],
                # A shared stack group makes Plotly add the y values cumulatively.
                stackgroup='one',
                hovertemplate=(
                    '<b>%{fullData.name}</b><br>'
                    'Time: %{x}<br>'
                    'Value: %{y}<extra></extra>'
                ),
            )
        )

    fig.update_layout(
        autosize=True,
        font_size=14,
        showlegend=True,
        legend=dict(
            title=legend_title,
            orientation="v",
            yanchor="top",
            y=1,
            xanchor="left",
            x=1.02,
        ),
        margin=dict(l=60, r=150, t=40, b=60),  # Extra right margin for legend
        plot_bgcolor='white',
        hovermode='x unified',
    )
    fig.update_xaxes(
        title_text="",
        showgrid=True,
        gridcolor='lightgray',
        gridwidth=1,
    )
    fig.update_yaxes(
        title_text="",
        showgrid=True,
        gridcolor='lightgray',
        gridwidth=1,
    )
    return fig
def create_line_plot(
    df,
    plot_choices,
    color_palette=None,
    *,
    group_col='status',
    x_col='period',
    width=1125,
    height=225,
):
    """Plot one line per distinct value of ``group_col`` against ``x_col``.

    Args:
        df: Table of observations; presumably a pandas DataFrame (it uses
            ``unique``, boolean indexing and ``sort_values``).
        plot_choices: Mapping of plot options:
            ``"y_col"`` (required): column plotted on the y axis; also used
            as the y-axis title.
            ``"y_format"`` (optional): when ``"percent"``, y values are
            multiplied by 100, shown with two decimals and a ``%`` in the
            hover label, and the axis title gets a " (%)" suffix.
            ``"y_log"`` (optional): truthy for a logarithmic y axis.
        color_palette: Line colours, cycled per group. ``None`` or an empty
            sequence falls back to ``plotly.express.colors.qualitative.Set1``.
        group_col: Column whose distinct values define the lines
            (default ``'status'``, the previously hard-coded column).
        x_col: Column plotted on the x axis; each line's rows are sorted by
            it (default ``'period'``, the previously hard-coded column).
        width: Figure width in pixels (default 1125).
        height: Figure height in pixels (default 225).

    Returns:
        A ``plotly.graph_objects.Figure``.

    Raises:
        KeyError: if ``plot_choices`` has no ``"y_col"`` entry.
    """
    y_col = plot_choices["y_col"]
    is_percent = plot_choices.get("y_format") == "percent"

    # An empty palette would make ``i % len(color_palette)`` divide by zero.
    if color_palette is None or len(color_palette) == 0:
        color_palette = px.colors.qualitative.Set1

    # Percent values are rescaled below, so show them with two decimals and a '%'.
    if is_percent:
        hovertemplate = (
            '<b>%{fullData.name}</b><br>'
            'Period: %{x}<br>'
            'Value: %{y:.2f}%<extra></extra>'
        )
    else:
        hovertemplate = '<b>%{fullData.name}</b><br>Period: %{x}<br>Value: %{y}<extra></extra>'

    fig = go.Figure()
    for i, group in enumerate(df[group_col].unique()):
        group_data = df[df[group_col] == group].sort_values(x_col)
        y_vals = group_data[y_col]
        if is_percent:
            y_vals = y_vals * 100

        fig.add_trace(
            go.Scatter(
                x=group_data[x_col],
                y=y_vals,
                name=group,
                mode='lines',
                line=dict(
                    color=color_palette[i % len(color_palette)],
                    width=3,
                ),
                opacity=0.85,
                hovertemplate=hovertemplate,
            )
        )

    fig.update_layout(
        width=width,
        height=height,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
        margin=dict(l=60, r=60, t=60, b=60),
        plot_bgcolor='white',
        hovermode='x unified',
    )
    fig.update_xaxes(
        title_text="Period",
        showgrid=False,
        zeroline=False,
    )

    y_title = y_col
    if is_percent:
        y_title += " (%)"
    fig.update_yaxes(
        title_text=y_title,
        showgrid=False,
        zeroline=False,
        type='log' if plot_choices.get("y_log") else 'linear',
    )
    return fig