Spaces:
Runtime error
Runtime error
Update clip_model.py
Browse files- clip_model.py +3 -3
clip_model.py
CHANGED
|
@@ -306,7 +306,7 @@ def find_matches(model, image_embeddings, query, image_filenames, n=9):
|
|
| 306 |
values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
|
| 307 |
matches = [image_filenames[idx] for idx in indices[::5]]
|
| 308 |
|
| 309 |
-
_, axes = plt.subplots(
|
| 310 |
|
| 311 |
results = []
|
| 312 |
for match, ax in zip(matches, axes.flatten()):
|
|
@@ -321,11 +321,11 @@ def find_matches(model, image_embeddings, query, image_filenames, n=9):
|
|
| 321 |
def clip_image_search(model,image_embeddings,
|
| 322 |
query,
|
| 323 |
image_filenames,
|
| 324 |
-
n=
|
| 325 |
_, valid_df = make_train_valid_dfs()
|
| 326 |
model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
|
| 327 |
return find_matches(model,
|
| 328 |
image_embeddings,
|
| 329 |
query,
|
| 330 |
image_filenames = valid_df['image'].values,
|
| 331 |
-
n
|
|
|
|
| 306 |
values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
|
| 307 |
matches = [image_filenames[idx] for idx in indices[::5]]
|
| 308 |
|
| 309 |
+
_, axes = plt.subplots(4, 4, figsize=(10, 10))
|
| 310 |
|
| 311 |
results = []
|
| 312 |
for match, ax in zip(matches, axes.flatten()):
|
|
|
|
| 321 |
def clip_image_search(model,image_embeddings,
|
| 322 |
query,
|
| 323 |
image_filenames,
|
| 324 |
+
n=16):
|
| 325 |
_, valid_df = make_train_valid_dfs()
|
| 326 |
model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
|
| 327 |
return find_matches(model,
|
| 328 |
image_embeddings,
|
| 329 |
query,
|
| 330 |
image_filenames = valid_df['image'].values,
|
| 331 |
+
n)
|