Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
72c9292
1
Parent(s):
1ef69ae
refactor model data loading to handle optional 'param_count' column and improve logging
Browse files
main.py
CHANGED
|
@@ -193,24 +193,38 @@ def setup_database():
|
|
| 193 |
)
|
| 194 |
|
| 195 |
# Load model data
|
| 196 |
-
|
| 197 |
"hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
|
| 198 |
)
|
| 199 |
-
model_row_count =
|
| 200 |
logger.info(f"Row count of new model data: {model_row_count}")
|
| 201 |
|
| 202 |
if model_collection.count() < model_row_count:
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
total_rows = len(model_df)
|
| 215 |
|
| 216 |
for i in range(0, total_rows, BATCH_SIZE):
|
|
|
|
| 193 |
)
|
| 194 |
|
| 195 |
# Load model data
|
| 196 |
+
model_lazy_df = pl.scan_parquet(
|
| 197 |
"hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
|
| 198 |
)
|
| 199 |
+
model_row_count = model_lazy_df.select(pl.len()).collect().item()
|
| 200 |
logger.info(f"Row count of new model data: {model_row_count}")
|
| 201 |
|
| 202 |
if model_collection.count() < model_row_count:
|
| 203 |
+
schema = model_lazy_df.schema
|
| 204 |
+
select_columns = [
|
| 205 |
+
"modelId",
|
| 206 |
+
"summary",
|
| 207 |
+
"likes",
|
| 208 |
+
"downloads",
|
| 209 |
+
"last_modified",
|
| 210 |
+
]
|
| 211 |
+
if "param_count" in schema:
|
| 212 |
+
logger.info("Found 'param_count' column in model data schema.")
|
| 213 |
+
select_columns.append("param_count")
|
| 214 |
+
else:
|
| 215 |
+
logger.warning(
|
| 216 |
+
"'param_count' column not found in model data schema. Will add it with null values."
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
# Select specified columns and then collect
|
| 220 |
+
model_df = model_lazy_df.select(select_columns).collect()
|
| 221 |
+
|
| 222 |
+
# If param_count was not in the original schema, add it now to the collected DataFrame
|
| 223 |
+
if "param_count" not in model_df.columns:
|
| 224 |
+
model_df = model_df.with_columns(
|
| 225 |
+
pl.lit(None).cast(pl.Int64).alias("param_count")
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
total_rows = len(model_df)
|
| 229 |
|
| 230 |
for i in range(0, total_rows, BATCH_SIZE):
|