# open-model-evolution / graphs/model_market_share.py
# Last change: "refactor code" (commit b466419) by emsesc.
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
# Loaded once, at import time, from a path relative to the current working
# directory. create_leaderboard reads its "author", "model" and "country"
# columns to attach a country to each developer / model.
# NOTE(review): unpickling can execute arbitrary code — only point this at
# trusted, locally produced data files.
filtered_df = pd.read_pickle("data_frames/filtered_df.pkl")
def create_stacked_area_chart(
    topk_df, gini_df, hhi_df, events, palette, start_time=None, end_time=None
):
    """Plot top-k market-share tiers as stacked areas with concentration overlays.

    Args:
        topk_df: Long-format frame with ``time``, ``metric`` and ``value``
            columns. One stacked band is drawn for each tier named in
            ``tiers`` below; other metric values are ignored.
        gini_df: Frame with ``time`` and ``value`` columns, drawn as a line on
            the secondary (right) y axis.
        hhi_df: Frame with ``time`` and ``value`` columns. The values are
            multiplied by 10 and drawn on the secondary y axis.
        events: Mapping of label -> x position. Each entry gets a dashed
            vertical line plus a text label just above the plot area.
        palette: Sequence of colours, cycled across the stacked bands.
        start_time: Optional lower bound (inclusive) on ``time``.
        end_time: Optional upper bound (inclusive) on ``time``.

    Returns:
        A plotly Figure with a secondary y axis.
    """
    tiers = (
        "Top 1",
        "Top 1 - 10",
        "Top 10 - 100",
        "Top 100 - 1000",
        "Top 1000 - 10000",
        "Rest",
    )
    band_hover = "<b>%{fullData.name}</b><br>Time: %{x}<br>Value: %{y}<extra></extra>"

    def clip(frame):
        # Sort chronologically, then drop rows outside the requested window.
        # NOTE(review): bounds are tested for truthiness here, whereas the
        # x-axis range below tests `is not None`; a falsy-but-not-None bound
        # (e.g. "") skips filtering yet still lands in the axis range.
        ordered = frame.sort_values("time")
        if start_time:
            ordered = ordered[ordered["time"] >= start_time]
        if end_time:
            ordered = ordered[ordered["time"] <= end_time]
        return ordered

    def add_overlay(label, colour, xs, ys):
        # Thick line on the right-hand axis with a 3-decimal hover value.
        fig.add_trace(
            go.Scatter(
                x=xs,
                y=ys,
                name=label,
                mode="lines",
                line=dict(color=colour, width=3),
                yaxis="y2",
                hovertemplate="<b>" + label + "</b><br>Time: %{x}<br>Value: %{y:.3f}<extra></extra>",
            ),
            secondary_y=True,
        )

    fig = make_subplots(specs=[[{"secondary_y": True}]])

    # Stacked bands, bottom tier first.
    for idx, tier in enumerate(tiers):
        band = clip(topk_df[topk_df["metric"] == tier])
        colour = palette[idx % len(palette)]
        fig.add_trace(
            go.Scatter(
                x=band["time"],
                y=band["value"],
                name=tier,
                mode="lines",
                line=dict(width=0, color=colour),
                fill="tozeroy" if idx == 0 else "tonexty",
                fillcolor=colour,
                stackgroup="one",
                hovertemplate=band_hover,
            ),
            secondary_y=False,
        )

    # Concentration indices on the right-hand axis.
    gini = clip(gini_df)
    add_overlay("Gini Coefficient", "#6b46c1", gini["time"], gini["value"])
    hhi = clip(hhi_df)
    add_overlay("HHI (×10)", "#ec4899", hhi["time"], hhi["value"] * 10)

    # Event markers: a dashed full-height line plus a label above the plot.
    for label, when in events.items():
        fig.add_shape(
            type="line",
            x0=when,
            x1=when,
            y0=0,
            y1=1,
            yref="paper",
            line=dict(color="#333333", width=2, dash="dash"),
        )
        fig.add_annotation(
            x=when,
            y=1,
            yref="paper",
            text=label,
            showarrow=False,
            yshift=10,
            font=dict(size=12),
        )

    fig.update_layout(
        autosize=True,
        font_size=14,
        showlegend=True,
        margin=dict(l=60, r=60, t=40, b=60),
        plot_bgcolor="white",
        hovermode="x unified",
    )

    # Pin the x range to whichever bounds were supplied (None = open end).
    if start_time is None and end_time is None:
        x_range = None
    else:
        x_range = [start_time, end_time]

    fig.update_xaxes(
        title_text="",
        showgrid=True,
        gridcolor="lightgray",
        gridwidth=1,
        range=x_range,
    )
    fig.update_yaxes(
        title_text="Model Market Share",
        showgrid=True,
        gridcolor="lightgray",
        gridwidth=1,
        secondary_y=False,
    )
    fig.update_yaxes(
        title_text="Concentration Indices", showgrid=False, secondary_y=True
    )
    return fig
def create_world_map(
    df,
    time_col="time",
    metric_col="metric",
    value_col="value",
    top_n_labels=10,
    start_time=None,
    end_time=None,
):
    """Draw a choropleth of each country's average value over a time window.

    Country names in ``metric_col`` are mapped to ISO 3166-1 alpha-3 codes.
    Rows whose name is not in the mapping are dropped. Within
    ``[start_time, end_time]`` (inclusive), ``value_col`` is averaged per
    country and displayed multiplied by 100 as a percentage.

    Args:
        df: Long-format frame with ``time_col``, ``metric_col`` and
            ``value_col`` columns. It is not modified.
        time_col: Name of the time column.
        metric_col: Name of the column holding country names.
        value_col: Name of the numeric value column.
        top_n_labels: Currently unused. It is kept for backward compatibility;
            the on-map percentage labels it controlled are disabled.
        start_time: Inclusive lower bound. Defaults to the earliest value in
            ``time_col``.
        end_time: Inclusive upper bound. Defaults to the latest value in
            ``time_col``.

    Returns:
        A plotly Figure containing a single choropleth trace. An empty ``df``
        produces an empty choropleth rather than raising.
    """
    country_code_map = {
        "Germany": "DEU",
        "United States of America": "USA",
        "China": "CHN",
        "France": "FRA",
        "India": "IND",
        "Israel": "ISR",
        "South Korea": "KOR",
        "United Kingdom": "GBR",
        "Switzerland": "CHE",
        "United Arab Emirates": "ARE",
        "Vietnam": "VNM",
        "Singapore": "SGP",
        "Chile": "CHL",
        "Hong Kong": "HKG",
        "Japan": "JPN",
        "Canada": "CAN",
        "Spain": "ESP",
        "Finland": "FIN",
        "Indonesia": "IDN",
        "Russia": "RUS",
        "Iran": "IRN",
        "Belarus": "BLR",
        "Thailand": "THA",
        "UAE": "ARE",
        "Argentina": "ARG",
        "Iceland": "ISL",
        "Poland": "POL",
        "Sweden": "SWE",
        "Taiwan": "TWN",
        "Lebanon": "LBN",
        "Algeria": "DZA",
        "Bulgaria": "BGR",
        "Norway": "NOR",
        "Netherlands": "NLD",
        "Hungary": "HUN",
        "Estonia": "EST",
        "Qatar": "QAT",
        "Brazil": "BRA",
        "Morocco": "MAR",
        "Slovenia": "SVN",
        "Ghana": "GHA",
        "Uganda": "UGA",
        "Turkey": "TUR",
    }

    # Derive a new frame instead of writing a column into the caller's df.
    mapped_data = df.assign(
        country_code=df[metric_col].map(country_code_map)
    ).dropna(subset=["country_code"])

    # Default the window to the full span of the data. min()/max() avoid
    # sorting every unique timestamp and do not raise on an empty frame.
    if start_time is None:
        start_time = df[time_col].min()
    if end_time is None:
        end_time = df[time_col].max()

    in_window = mapped_data[
        (mapped_data[time_col] >= start_time) & (mapped_data[time_col] <= end_time)
    ]

    # Average each country's value across the window.
    initial_data = (
        in_window.groupby([metric_col, "country_code"])[value_col]
        .mean()
        .reset_index()
    )
    initial_data["percentage"] = initial_data[value_col] * 100
    initial_data = initial_data.sort_values("percentage", ascending=False)

    hover_text = [
        f"<b>{name}</b><br>"
        f"Avg Downloads: {pct:.1f}% of total<br>"
        f"Avg Value: {val:.6f}"
        for name, pct, val in zip(
            initial_data[metric_col],
            initial_data["percentage"],
            initial_data[value_col],
        )
    ]

    fig = make_subplots(
        rows=1,
        cols=1,
        specs=[[{"type": "geo"}]],
    )
    fig.add_trace(
        go.Choropleth(
            locations=initial_data["country_code"],
            z=initial_data["percentage"],
            text=hover_text,
            hovertemplate="%{text}<extra></extra>",
            colorscale=[
                "#001219",
                "#0a9396",
                "#94d2bd",
                "#e9d8a6",
                "#ee9b00",
                "#ca6702",
                "#bb3e03",
                "#9b2226",
            ],
            colorbar=dict(
                title="Avg % of Total Downloads",
                tickfont=dict(size=12),
                len=0.6,
                x=1.02,
                y=0.7,
            ),
            marker_line_color="#ffffff",
            marker_line_width=1.5,
            geo="geo",
        ),
        row=1,
        col=1,
    )

    fig.update_layout(
        title=dict(
            text="Model Downloads by Country",
            x=0.5,
            font=dict(size=20),
        ),
        width=1200,
        height=800,
        plot_bgcolor="#ffffff",
        paper_bgcolor="#ffffff",
        margin=dict(l=0, r=120, t=100, b=60),
    )
    fig.update_geos(
        showframe=False,
        showland=True,
        landcolor="#d0cfcf",
        coastlinecolor="#b8b8b8",
        projection_type="natural earth",
        bgcolor="#ffffff",
    )
    return fig
def create_range_slider(df):
    """Return a slim, date-typed x-axis figure spanning every time in ``df``.

    The figure holds one fully transparent line along y=0, which makes the
    x axis cover the distinct values of ``df["time"]``. The y axis is hidden.
    Despite this function's name, Plotly's built-in range slider is explicitly
    disabled (``rangeslider.visible=False``).

    Args:
        df: DataFrame expected to contain a ``time`` column.

    Returns:
        A 100-px-tall ``go.Figure``. An empty ``go.Figure`` is returned when
        ``df`` is empty or has no ``time`` column.
    """
    has_time = "time" in df.columns
    if df.empty or not has_time:
        return go.Figure()

    distinct_times = sorted(df["time"].unique())
    anchor = go.Scatter(
        x=distinct_times,
        y=[0 for _ in distinct_times],
        mode="lines",
        line=dict(color="rgba(0,0,0,0)"),  # fully transparent
        hoverinfo="skip",
        showlegend=False,
    )

    fig = go.Figure()
    fig.add_trace(anchor)
    # Date-typed x axis with the built-in slider hidden; y axis hidden too.
    fig.update_layout(
        xaxis_type="date",
        xaxis_rangeslider_visible=False,
        yaxis_visible=False,
        margin=dict(t=20, b=20, l=20, r=20),
        height=100,
    )
    return fig
def create_leaderboard(country_df, developer_df, model_df, start_time=None, end_time=None, top_n=10):
    """Build three side-by-side top-N tables for countries, developers and models.

    Each input is a long-format frame with ``time``, ``metric`` and ``value``
    columns. ``metric`` holds the country, developer (author) or model name.
    Values inside ``[start_time, end_time]`` (inclusive) are summed per name.
    The ``top_n`` largest are listed with their percentage of the window's
    grand total for that table.

    Countries are annotated with their flag emoji. Developers and models are
    annotated with the flag of the country recorded for them in the
    module-level ``filtered_df``, using its ``author``/``model`` -> ``country``
    columns.

    Args:
        country_df: Country-level frame. Its time span also supplies the
            default window.
        developer_df: Developer (author)-level frame.
        model_df: Model-level frame.
        start_time: Inclusive lower bound. Defaults to the earliest
            ``country_df`` time.
        end_time: Inclusive upper bound. Defaults to the latest
            ``country_df`` time.
        top_n: Number of rows per table.

    Returns:
        A plotly Figure with three table subplots. An empty Figure is returned
        when no input has rows inside the window. The input frames are not
        modified.
    """
    country_emoji_map = {
        "United States of America": "🇺🇸",
        "China": "🇨🇳",
        "Germany": "🇩🇪",
        "France": "🇫🇷",
        "India": "🇮🇳",
        "Italy": "🇮🇹",
        "Japan": "🇯🇵",
        "South Korea": "🇰🇷",
        "United Kingdom": "🇬🇧",
        "Canada": "🇨🇦",
        "Brazil": "🇧🇷",
        "Australia": "🇦🇺",
        "Unknown": "❓",
        "Finland": "🇫🇮",
        "Lebanon": "🇱🇧",
    }

    def country_lookup(key_col):
        # One country per key (first non-null entry). A plain merge would
        # duplicate rows, and inflate sums, whenever a key appears in
        # filtered_df with more than one country.
        pairs = (
            filtered_df[[key_col, "country"]]
            .dropna(subset=["country"])
            .drop_duplicates(subset=key_col)
        )
        return pairs.set_index(key_col)["country"]

    author_country = country_lookup("author")
    model_country = country_lookup("model")

    # Parse times without writing back into the caller's frames.
    country_times = pd.to_datetime(country_df["time"])
    developer_times = pd.to_datetime(developer_df["time"])
    model_times = pd.to_datetime(model_df["time"])

    if start_time is None:
        start_time = country_times.min()
    if end_time is None:
        end_time = country_times.max()

    def in_window(df, times):
        # Positional mask, so duplicate index labels cannot misalign it.
        mask = ((times >= start_time) & (times <= end_time)).to_numpy()
        return df[mask]

    country_df_filtered = in_window(country_df, country_times)
    developer_df_filtered = in_window(developer_df, developer_times)
    model_df_filtered = in_window(model_df, model_times)

    if country_df_filtered.empty and developer_df_filtered.empty and model_df_filtered.empty:
        return go.Figure()

    def get_top_n_leaderboard(df, label, country_of=None):
        # Rank names by summed value. The percentage is relative to the total
        # over *all* names in the window, not just the rows shown.
        totals = df.groupby("metric")["value"].sum()
        top = (
            totals.sort_values(ascending=False)
            .head(top_n)
            .reset_index()
            .rename(columns={"metric": label, "value": "Total Value"})
        )
        grand_total = totals.sum()
        if grand_total > 0:
            top["% of total"] = top["Total Value"] / grand_total * 100
        else:
            top["% of total"] = 0
        if country_of is None:
            # Rows are countries themselves.
            top["Attributes"] = top[label].map(country_emoji_map).fillna("")
        else:
            top["Attributes"] = (
                top[label].map(country_of).map(country_emoji_map).fillna("")
            )
        return top[[label, "Attributes", "% of total"]]

    tables = (
        get_top_n_leaderboard(country_df_filtered, "Country"),
        get_top_n_leaderboard(developer_df_filtered, "Developer", author_country),
        get_top_n_leaderboard(model_df_filtered, "Model", model_country),
    )

    fig = make_subplots(
        rows=1, cols=3,
        subplot_titles=("Top Countries", "Top Developers", "Top Models"),
        specs=[[{"type": "table"}, {"type": "table"}, {"type": "table"}]]
    )
    for col_idx, table in enumerate(tables, start=1):
        fig.add_trace(
            go.Table(
                header=dict(values=list(table.columns),
                            fill_color="lightgrey", align="left"),
                cells=dict(values=[table[col] for col in table.columns],
                           fill_color="white", align="left"),
            ),
            row=1, col=col_idx
        )

    fig.update_layout(
        height=400,
        showlegend=False,
        title_text="Leaderboards"
    )
    return fig