Spaces:
Running
Running
Commit
·
6d540bf
1
Parent(s):
f276a79
update
Browse files
app.py
CHANGED
|
@@ -8,6 +8,9 @@ from huggingface_hub.utils._errors import EntryNotFoundError, RepositoryNotFound
|
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
from matplotlib.colors import LinearSegmentedColormap
|
| 10 |
import plotly.express as px
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
load_dotenv()
|
| 13 |
webhook_url = os.environ.get("WEBHOOK_URL")
|
|
@@ -271,6 +274,29 @@ for folder in get_folders_matching_format('data'):
|
|
| 271 |
pd.read_excel(final_file_name + '.xlsx', sheet_name=sheet_name))
|
| 272 |
|
| 273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
def create_scaling_plot(all_data, period):
|
| 275 |
selected_columns = ['Name', 'Parameters Count (B)', 'Average (The lower the better)']
|
| 276 |
target_data = all_data[period]
|
|
@@ -284,12 +310,36 @@ def create_scaling_plot(all_data, period):
|
|
| 284 |
'Average (The lower the better)': 'Compression Rate (%)'
|
| 285 |
}, inplace=True)
|
| 286 |
|
|
|
|
| 287 |
fig = px.scatter(new_df,
|
| 288 |
x='Params(B)',
|
| 289 |
y='Compression Rate (%)',
|
| 290 |
title='Compression Rate Scaling Law',
|
| 291 |
hover_name='Name'
|
| 292 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
fig.update_traces(marker=dict(size=12))
|
| 294 |
return fig
|
| 295 |
|
|
|
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
from matplotlib.colors import LinearSegmentedColormap
|
| 10 |
import plotly.express as px
|
| 11 |
+
import plotly.graph_objects as go
|
| 12 |
+
from sklearn.linear_model import LinearRegression
|
| 13 |
+
import numpy as np
|
| 14 |
|
| 15 |
load_dotenv()
|
| 16 |
webhook_url = os.environ.get("WEBHOOK_URL")
|
|
|
|
| 274 |
pd.read_excel(final_file_name + '.xlsx', sheet_name=sheet_name))
|
| 275 |
|
| 276 |
|
| 277 |
+
# def create_scaling_plot(all_data, period):
|
| 278 |
+
# selected_columns = ['Name', 'Parameters Count (B)', 'Average (The lower the better)']
|
| 279 |
+
# target_data = all_data[period]
|
| 280 |
+
# new_df = pd.DataFrame()
|
| 281 |
+
#
|
| 282 |
+
# for size in target_data.keys():
|
| 283 |
+
# new_df = pd.concat([new_df, target_data[size]['cr'].loc[:, selected_columns]], axis=0)
|
| 284 |
+
#
|
| 285 |
+
# new_df.rename(columns={
|
| 286 |
+
# 'Parameters Count (B)': 'Params(B)',
|
| 287 |
+
# 'Average (The lower the better)': 'Compression Rate (%)'
|
| 288 |
+
# }, inplace=True)
|
| 289 |
+
#
|
| 290 |
+
# fig = px.scatter(new_df,
|
| 291 |
+
# x='Params(B)',
|
| 292 |
+
# y='Compression Rate (%)',
|
| 293 |
+
# title='Compression Rate Scaling Law',
|
| 294 |
+
# hover_name='Name'
|
| 295 |
+
# )
|
| 296 |
+
# fig.update_traces(marker=dict(size=12))
|
| 297 |
+
# return fig
|
| 298 |
+
|
| 299 |
+
|
| 300 |
def create_scaling_plot(all_data, period):
|
| 301 |
selected_columns = ['Name', 'Parameters Count (B)', 'Average (The lower the better)']
|
| 302 |
target_data = all_data[period]
|
|
|
|
| 310 |
'Average (The lower the better)': 'Compression Rate (%)'
|
| 311 |
}, inplace=True)
|
| 312 |
|
| 313 |
+
# Create scatter plot
|
| 314 |
fig = px.scatter(new_df,
|
| 315 |
x='Params(B)',
|
| 316 |
y='Compression Rate (%)',
|
| 317 |
title='Compression Rate Scaling Law',
|
| 318 |
hover_name='Name'
|
| 319 |
)
|
| 320 |
+
|
| 321 |
+
# Add logarithmic trendline
|
| 322 |
+
X = new_df[['Params(B)']].values
|
| 323 |
+
y = new_df['Compression Rate (%)'].values
|
| 324 |
+
|
| 325 |
+
# Perform log transformation on X
|
| 326 |
+
X_log = np.log(X)
|
| 327 |
+
|
| 328 |
+
model = LinearRegression()
|
| 329 |
+
model.fit(X_log, y)
|
| 330 |
+
|
| 331 |
+
# Create trendline data for plot
|
| 332 |
+
X_plot = np.linspace(X_log.min() - 1, X_log.max() + 0.1, 100)
|
| 333 |
+
y_plot = model.predict(X_plot.reshape(-1, 1))
|
| 334 |
+
|
| 335 |
+
fig.add_trace(go.Scatter(
|
| 336 |
+
x=np.exp(X_plot),
|
| 337 |
+
y=y_plot,
|
| 338 |
+
mode='lines',
|
| 339 |
+
name='Trend',
|
| 340 |
+
line=dict(color='#39C5BB')
|
| 341 |
+
))
|
| 342 |
+
|
| 343 |
fig.update_traces(marker=dict(size=12))
|
| 344 |
return fig
|
| 345 |
|