Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
ac80dbe
1
Parent(s):
42e0474
add parameter count filters to model search and trending endpoints
Browse files
main.py
CHANGED
|
@@ -366,29 +366,64 @@ async def find_similar_datasets(
|
|
| 366 |
@cache(ttl=CACHE_TTL)
|
| 367 |
async def search_models(
|
| 368 |
query: str,
|
| 369 |
-
k: int = Query(default=5, ge=1, le=100),
|
| 370 |
sort_by: str = Query(
|
| 371 |
-
default="similarity",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
),
|
| 373 |
-
min_likes: int = Query(default=0, ge=0),
|
| 374 |
-
min_downloads: int = Query(default=0, ge=0),
|
| 375 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
try:
|
| 377 |
collection = client.get_collection(
|
| 378 |
name="model_cards", embedding_function=get_embedding_function()
|
| 379 |
)
|
| 380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
results = collection.query(
|
| 382 |
query_texts=[f"search_query: {query}"],
|
| 383 |
n_results=k * 4 if sort_by != "similarity" else k,
|
| 384 |
-
where=
|
| 385 |
-
"$and": [
|
| 386 |
-
{"likes": {"$gte": min_likes}},
|
| 387 |
-
{"downloads": {"$gte": min_downloads}},
|
| 388 |
-
]
|
| 389 |
-
}
|
| 390 |
-
if min_likes > 0 or min_downloads > 0
|
| 391 |
-
else None,
|
| 392 |
)
|
| 393 |
|
| 394 |
query_results = await process_search_results(results, "model", k, sort_by)
|
|
@@ -404,13 +439,31 @@ async def search_models(
|
|
| 404 |
@cache(ttl=CACHE_TTL)
|
| 405 |
async def find_similar_models(
|
| 406 |
model_id: str,
|
| 407 |
-
k: int = Query(default=5, ge=1, le=100),
|
| 408 |
sort_by: str = Query(
|
| 409 |
-
default="similarity",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
),
|
| 411 |
-
min_likes: int = Query(default=0, ge=0),
|
| 412 |
-
min_downloads: int = Query(default=0, ge=0),
|
| 413 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
try:
|
| 415 |
collection = client.get_collection("model_cards")
|
| 416 |
|
|
@@ -421,17 +474,34 @@ async def find_similar_models(
|
|
| 421 |
status_code=404, detail=f"Model ID '{model_id}' not found"
|
| 422 |
)
|
| 423 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
results = collection.query(
|
| 425 |
query_embeddings=[results["embeddings"][0]],
|
| 426 |
n_results=k * 4 if sort_by != "similarity" else k + 1,
|
| 427 |
-
where=
|
| 428 |
-
"$and": [
|
| 429 |
-
{"likes": {"$gte": min_likes}},
|
| 430 |
-
{"downloads": {"$gte": min_downloads}},
|
| 431 |
-
]
|
| 432 |
-
}
|
| 433 |
-
if min_likes > 0 or min_downloads > 0
|
| 434 |
-
else None,
|
| 435 |
)
|
| 436 |
|
| 437 |
query_results = await process_search_results(
|
|
@@ -538,6 +608,8 @@ async def get_trending_models_with_summaries(
|
|
| 538 |
limit: int = 10,
|
| 539 |
min_likes: int = 0,
|
| 540 |
min_downloads: int = 0,
|
|
|
|
|
|
|
| 541 |
) -> List[ModelQueryResult]:
|
| 542 |
"""Fetch trending models and combine with summaries from database"""
|
| 543 |
try:
|
|
@@ -573,13 +645,30 @@ async def get_trending_models_with_summaries(
|
|
| 573 |
for model in trending_models:
|
| 574 |
if model["modelId"] in id_to_summary:
|
| 575 |
metadata = id_to_metadata.get(model["modelId"], {})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 576 |
result = ModelQueryResult(
|
| 577 |
model_id=model["modelId"],
|
| 578 |
similarity=1.0, # Not applicable for trending
|
| 579 |
summary=id_to_summary[model["modelId"]],
|
| 580 |
likes=model.get("likes", 0),
|
| 581 |
downloads=model.get("downloads", 0),
|
| 582 |
-
param_count=
|
| 583 |
)
|
| 584 |
results.append(result)
|
| 585 |
|
|
@@ -592,13 +681,34 @@ async def get_trending_models_with_summaries(
|
|
| 592 |
|
| 593 |
@app.get("/trending/models", response_model=ModelQueryResponse)
|
| 594 |
async def get_trending_models(
|
| 595 |
-
limit: int = Query(
|
| 596 |
-
|
| 597 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 598 |
):
|
| 599 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 600 |
results = await get_trending_models_with_summaries(
|
| 601 |
-
limit=limit,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 602 |
)
|
| 603 |
return ModelQueryResponse(results=results)
|
| 604 |
|
|
|
|
| 366 |
@cache(ttl=CACHE_TTL)
|
| 367 |
async def search_models(
|
| 368 |
query: str,
|
| 369 |
+
k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
|
| 370 |
sort_by: str = Query(
|
| 371 |
+
default="similarity",
|
| 372 |
+
enum=["similarity", "likes", "downloads", "trending"],
|
| 373 |
+
description="Sort method for results",
|
| 374 |
+
),
|
| 375 |
+
min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
|
| 376 |
+
min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
|
| 377 |
+
min_param_count: int = Query(
|
| 378 |
+
default=0,
|
| 379 |
+
ge=0,
|
| 380 |
+
description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
|
| 381 |
+
),
|
| 382 |
+
max_param_count: Optional[int] = Query(
|
| 383 |
+
default=None,
|
| 384 |
+
ge=0,
|
| 385 |
+
description="Maximum parameter count (None means no upper limit)",
|
| 386 |
),
|
|
|
|
|
|
|
| 387 |
):
|
| 388 |
+
"""
|
| 389 |
+
Search for models based on a text query with optional filtering.
|
| 390 |
+
|
| 391 |
+
- When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
|
| 392 |
+
- param_count=0 indicates missing/unknown parameter count in the dataset
|
| 393 |
+
"""
|
| 394 |
try:
|
| 395 |
collection = client.get_collection(
|
| 396 |
name="model_cards", embedding_function=get_embedding_function()
|
| 397 |
)
|
| 398 |
|
| 399 |
+
where_conditions = []
|
| 400 |
+
if min_likes > 0:
|
| 401 |
+
where_conditions.append({"likes": {"$gte": min_likes}})
|
| 402 |
+
if min_downloads > 0:
|
| 403 |
+
where_conditions.append({"downloads": {"$gte": min_downloads}})
|
| 404 |
+
|
| 405 |
+
# Add parameter count filters
|
| 406 |
+
using_param_filters = min_param_count > 0 or max_param_count is not None
|
| 407 |
+
if using_param_filters:
|
| 408 |
+
# Always exclude zero param count when using any parameter filters
|
| 409 |
+
where_conditions.append({"param_count": {"$gt": 0}})
|
| 410 |
+
|
| 411 |
+
if min_param_count > 0:
|
| 412 |
+
where_conditions.append({"param_count": {"$gte": min_param_count}})
|
| 413 |
+
if max_param_count is not None:
|
| 414 |
+
where_conditions.append({"param_count": {"$lte": max_param_count}})
|
| 415 |
+
|
| 416 |
+
# Handle where clause creation based on number of conditions
|
| 417 |
+
where_clause = None
|
| 418 |
+
if len(where_conditions) > 1:
|
| 419 |
+
where_clause = {"$and": where_conditions}
|
| 420 |
+
elif len(where_conditions) == 1:
|
| 421 |
+
where_clause = where_conditions[0] # Single condition without $and
|
| 422 |
+
|
| 423 |
results = collection.query(
|
| 424 |
query_texts=[f"search_query: {query}"],
|
| 425 |
n_results=k * 4 if sort_by != "similarity" else k,
|
| 426 |
+
where=where_clause,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
)
|
| 428 |
|
| 429 |
query_results = await process_search_results(results, "model", k, sort_by)
|
|
|
|
| 439 |
@cache(ttl=CACHE_TTL)
|
| 440 |
async def find_similar_models(
|
| 441 |
model_id: str,
|
| 442 |
+
k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
|
| 443 |
sort_by: str = Query(
|
| 444 |
+
default="similarity",
|
| 445 |
+
enum=["similarity", "likes", "downloads", "trending"],
|
| 446 |
+
description="Sort method for results",
|
| 447 |
+
),
|
| 448 |
+
min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
|
| 449 |
+
min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
|
| 450 |
+
min_param_count: int = Query(
|
| 451 |
+
default=0,
|
| 452 |
+
ge=0,
|
| 453 |
+
description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
|
| 454 |
+
),
|
| 455 |
+
max_param_count: Optional[int] = Query(
|
| 456 |
+
default=None,
|
| 457 |
+
ge=0,
|
| 458 |
+
description="Maximum parameter count (None means no upper limit)",
|
| 459 |
),
|
|
|
|
|
|
|
| 460 |
):
|
| 461 |
+
"""
|
| 462 |
+
Find similar models to a specified model with optional filtering.
|
| 463 |
+
|
| 464 |
+
- When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
|
| 465 |
+
- param_count=0 indicates missing/unknown parameter count in the dataset
|
| 466 |
+
"""
|
| 467 |
try:
|
| 468 |
collection = client.get_collection("model_cards")
|
| 469 |
|
|
|
|
| 474 |
status_code=404, detail=f"Model ID '{model_id}' not found"
|
| 475 |
)
|
| 476 |
|
| 477 |
+
where_conditions = []
|
| 478 |
+
if min_likes > 0:
|
| 479 |
+
where_conditions.append({"likes": {"$gte": min_likes}})
|
| 480 |
+
if min_downloads > 0:
|
| 481 |
+
where_conditions.append({"downloads": {"$gte": min_downloads}})
|
| 482 |
+
|
| 483 |
+
# Add parameter count filters
|
| 484 |
+
using_param_filters = min_param_count > 0 or max_param_count is not None
|
| 485 |
+
if using_param_filters:
|
| 486 |
+
# Always exclude zero param count when using any parameter filters
|
| 487 |
+
where_conditions.append({"param_count": {"$gt": 0}})
|
| 488 |
+
|
| 489 |
+
if min_param_count > 0:
|
| 490 |
+
where_conditions.append({"param_count": {"$gte": min_param_count}})
|
| 491 |
+
if max_param_count is not None:
|
| 492 |
+
where_conditions.append({"param_count": {"$lte": max_param_count}})
|
| 493 |
+
|
| 494 |
+
# Handle where clause creation based on number of conditions
|
| 495 |
+
where_clause = None
|
| 496 |
+
if len(where_conditions) > 1:
|
| 497 |
+
where_clause = {"$and": where_conditions}
|
| 498 |
+
elif len(where_conditions) == 1:
|
| 499 |
+
where_clause = where_conditions[0] # Single condition without $and
|
| 500 |
+
|
| 501 |
results = collection.query(
|
| 502 |
query_embeddings=[results["embeddings"][0]],
|
| 503 |
n_results=k * 4 if sort_by != "similarity" else k + 1,
|
| 504 |
+
where=where_clause,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
)
|
| 506 |
|
| 507 |
query_results = await process_search_results(
|
|
|
|
| 608 |
limit: int = 10,
|
| 609 |
min_likes: int = 0,
|
| 610 |
min_downloads: int = 0,
|
| 611 |
+
min_param_count: int = 0,
|
| 612 |
+
max_param_count: Optional[int] = None,
|
| 613 |
) -> List[ModelQueryResult]:
|
| 614 |
"""Fetch trending models and combine with summaries from database"""
|
| 615 |
try:
|
|
|
|
| 645 |
for model in trending_models:
|
| 646 |
if model["modelId"] in id_to_summary:
|
| 647 |
metadata = id_to_metadata.get(model["modelId"], {})
|
| 648 |
+
param_count = metadata.get("param_count", 0)
|
| 649 |
+
|
| 650 |
+
# Apply parameter count filters
|
| 651 |
+
using_param_filters = min_param_count > 0 or max_param_count is not None
|
| 652 |
+
|
| 653 |
+
# Skip if param_count is 0 and we're using param filters
|
| 654 |
+
if using_param_filters and param_count == 0:
|
| 655 |
+
continue
|
| 656 |
+
|
| 657 |
+
# Skip if param_count is less than min_param_count
|
| 658 |
+
if min_param_count > 0 and param_count < min_param_count:
|
| 659 |
+
continue
|
| 660 |
+
|
| 661 |
+
# Skip if param_count is greater than max_param_count
|
| 662 |
+
if max_param_count is not None and param_count > max_param_count:
|
| 663 |
+
continue
|
| 664 |
+
|
| 665 |
result = ModelQueryResult(
|
| 666 |
model_id=model["modelId"],
|
| 667 |
similarity=1.0, # Not applicable for trending
|
| 668 |
summary=id_to_summary[model["modelId"]],
|
| 669 |
likes=model.get("likes", 0),
|
| 670 |
downloads=model.get("downloads", 0),
|
| 671 |
+
param_count=param_count,
|
| 672 |
)
|
| 673 |
results.append(result)
|
| 674 |
|
|
|
|
| 681 |
|
| 682 |
@app.get("/trending/models", response_model=ModelQueryResponse)
|
| 683 |
async def get_trending_models(
|
| 684 |
+
limit: int = Query(
|
| 685 |
+
default=10, ge=1, le=100, description="Number of results to return"
|
| 686 |
+
),
|
| 687 |
+
min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
|
| 688 |
+
min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
|
| 689 |
+
min_param_count: int = Query(
|
| 690 |
+
default=0,
|
| 691 |
+
ge=0,
|
| 692 |
+
description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
|
| 693 |
+
),
|
| 694 |
+
max_param_count: Optional[int] = Query(
|
| 695 |
+
default=None,
|
| 696 |
+
ge=0,
|
| 697 |
+
description="Maximum parameter count (None means no upper limit)",
|
| 698 |
+
),
|
| 699 |
):
|
| 700 |
+
"""
|
| 701 |
+
Get trending models with their summaries and optional filtering.
|
| 702 |
+
|
| 703 |
+
- When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
|
| 704 |
+
- param_count=0 indicates missing/unknown parameter count in the dataset
|
| 705 |
+
"""
|
| 706 |
results = await get_trending_models_with_summaries(
|
| 707 |
+
limit=limit,
|
| 708 |
+
min_likes=min_likes,
|
| 709 |
+
min_downloads=min_downloads,
|
| 710 |
+
min_param_count=min_param_count,
|
| 711 |
+
max_param_count=max_param_count,
|
| 712 |
)
|
| 713 |
return ModelQueryResponse(results=results)
|
| 714 |
|