Spaces:
Sleeping
Sleeping
notbulubula
commited on
Commit
·
1638e8f
1
Parent(s):
9b37c86
naprawianie all
Browse files
app.py
CHANGED
|
@@ -4,7 +4,7 @@ import pandas as pd
|
|
| 4 |
import os
|
| 5 |
# import matplotlib.pyplot as plt
|
| 6 |
|
| 7 |
-
from utils import fetch_runs_to_df, fetch_run
|
| 8 |
|
| 9 |
# Access the API key from the environment variable
|
| 10 |
wandb_api_key = os.getenv('WANDB_API_KEY')
|
|
@@ -84,27 +84,8 @@ if option == "Models":
|
|
| 84 |
|
| 85 |
# Ensure the DataFrame is not empty
|
| 86 |
if not df.empty:
|
| 87 |
-
# Fetch metrics for ranking (e.g., accuracy or loss)
|
| 88 |
-
ranking_data = []
|
| 89 |
-
for index, row in df.iterrows():
|
| 90 |
-
try:
|
| 91 |
-
# Fetch the run details
|
| 92 |
-
run = api.run(f"{projects[selected_project]['entity']}/{projects[selected_project]['project']}/{row['ID']}")
|
| 93 |
-
metrics = run.summary
|
| 94 |
-
model_name = run.config.get("model_name", "Unknown") # Fetch model name from the config, defaulting to "Unknown"
|
| 95 |
-
|
| 96 |
-
ranking_data.append({
|
| 97 |
-
"Model Name": model_name, # Add model name to the table
|
| 98 |
-
"Run Name": row["Run Name"],
|
| 99 |
-
"ID": row["ID"],
|
| 100 |
-
"Accuracy": metrics.get("accuracy"), # Example metric
|
| 101 |
-
"Loss": metrics.get("loss") # Example metric
|
| 102 |
-
})
|
| 103 |
-
except wandb.errors.CommError:
|
| 104 |
-
continue
|
| 105 |
-
|
| 106 |
# Convert to DataFrame
|
| 107 |
-
ranking_df =
|
| 108 |
|
| 109 |
# Rank by Accuracy (or another metric)
|
| 110 |
ranking_df = ranking_df.sort_values(by="Accuracy", ascending=False).reset_index(drop=True)
|
|
|
|
| 4 |
import os
|
| 5 |
# import matplotlib.pyplot as plt
|
| 6 |
|
| 7 |
+
from utils import fetch_runs_to_df, fetch_run, fetch_models_to_df
|
| 8 |
|
| 9 |
# Access the API key from the environment variable
|
| 10 |
wandb_api_key = os.getenv('WANDB_API_KEY')
|
|
|
|
| 84 |
|
| 85 |
# Ensure the DataFrame is not empty
|
| 86 |
if not df.empty:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
# Convert to DataFrame
|
| 88 |
+
ranking_df = fetch_models_to_df(api, projects, selected_project, df)
|
| 89 |
|
| 90 |
# Rank by Accuracy (or another metric)
|
| 91 |
ranking_df = ranking_df.sort_values(by="Accuracy", ascending=False).reset_index(drop=True)
|
utils.py
CHANGED
|
@@ -57,4 +57,43 @@ def fetch_run(api, projects, selected_project, selected_run_id):
|
|
| 57 |
project = projects[selected_project]["project"]
|
| 58 |
run = api.run(f"{entity}/{project}/{selected_run_id}")
|
| 59 |
|
| 60 |
-
return run
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
project = projects[selected_project]["project"]
|
| 58 |
run = api.run(f"{entity}/{project}/{selected_run_id}")
|
| 59 |
|
| 60 |
+
return run
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def fetch_models_to_df(api, projects, selected_project, df):
|
| 64 |
+
data = []
|
| 65 |
+
for index, row in df.iterrows():
|
| 66 |
+
try:
|
| 67 |
+
if selected_project == "All":
|
| 68 |
+
# Determine the project for the current run
|
| 69 |
+
for project_name, details in projects.items():
|
| 70 |
+
entity = details["entity"]
|
| 71 |
+
project = details["project"]
|
| 72 |
+
try:
|
| 73 |
+
run = api.run(f"{entity}/{project}/{row['ID']}")
|
| 74 |
+
break
|
| 75 |
+
except wandb.errors.CommError:
|
| 76 |
+
continue
|
| 77 |
+
else:
|
| 78 |
+
st.error(f"Run ID {row['ID']} not found in any project.")
|
| 79 |
+
continue
|
| 80 |
+
else:
|
| 81 |
+
entity = projects[selected_project]["entity"]
|
| 82 |
+
project = projects[selected_project]["project"]
|
| 83 |
+
run = api.run(f"{entity}/{project}/{row['ID']}")
|
| 84 |
+
|
| 85 |
+
metrics = run.summary
|
| 86 |
+
model_name = run.config.get("model_name", "Unknown") # Fetch model name from the config, defaulting to "Unknown"
|
| 87 |
+
|
| 88 |
+
data.append({
|
| 89 |
+
"Model Name": model_name, # Add model name to the table
|
| 90 |
+
"Run Name": row["Run Name"],
|
| 91 |
+
"ID": row["ID"],
|
| 92 |
+
"Accuracy": metrics.get("accuracy"), # Example metric
|
| 93 |
+
"Loss": metrics.get("loss") # Example metric
|
| 94 |
+
})
|
| 95 |
+
except wandb.errors.CommError:
|
| 96 |
+
continue
|
| 97 |
+
|
| 98 |
+
data_df = pd.DataFrame(data)
|
| 99 |
+
return data_df
|