subiendo la demo a hf
Browse files- app.py +4 -2
- src/__pycache__/model_LN_prompt.cpython-310.pyc +0 -0
- src/model_LN_prompt.py +3 -6
app.py
CHANGED
|
@@ -5,6 +5,7 @@ from multiprocessing.dummy import Pool
|
|
| 5 |
import base64
|
| 6 |
from PIL import Image, ImageOps
|
| 7 |
import torch
|
|
|
|
| 8 |
from torchvision import transforms
|
| 9 |
from streamlit_drawable_canvas import st_canvas
|
| 10 |
from src.model_LN_prompt import Model
|
|
@@ -83,8 +84,9 @@ def compute_sketch(_sketch, model):
|
|
| 83 |
def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
|
| 84 |
query_embedding = compute_sketch(_query, model)
|
| 85 |
corpus_id = 0 if corpus == "Unsplash" else 1
|
| 86 |
-
image_features = torch.
|
| 87 |
-
|
|
|
|
| 88 |
|
| 89 |
dot_product = (image_features @ query_embedding.T)[:, 0]
|
| 90 |
_, max_indices = torch.topk(
|
|
|
|
| 5 |
import base64
|
| 6 |
from PIL import Image, ImageOps
|
| 7 |
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
from torchvision import transforms
|
| 10 |
from streamlit_drawable_canvas import st_canvas
|
| 11 |
from src.model_LN_prompt import Model
|
|
|
|
| 84 |
def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
|
| 85 |
query_embedding = compute_sketch(_query, model)
|
| 86 |
corpus_id = 0 if corpus == "Unsplash" else 1
|
| 87 |
+
image_features = torch.from_numpy(
|
| 88 |
+
np.array([item[0] for item in embeddings[corpus_id]])
|
| 89 |
+
).to(device)
|
| 90 |
|
| 91 |
dot_product = (image_features @ query_embedding.T)[:, 0]
|
| 92 |
_, max_indices = torch.topk(
|
src/__pycache__/model_LN_prompt.cpython-310.pyc
CHANGED
|
Binary files a/src/__pycache__/model_LN_prompt.cpython-310.pyc and b/src/__pycache__/model_LN_prompt.cpython-310.pyc differ
|
|
|
src/model_LN_prompt.py
CHANGED
|
@@ -32,14 +32,11 @@ class Model(pl.LightningModule):
|
|
| 32 |
|
| 33 |
|
| 34 |
def configure_optimizers(self):
|
| 35 |
-
|
| 36 |
-
model_params = list(self.dino.parameters())
|
| 37 |
-
else:
|
| 38 |
-
model_params = list(self.dino.parameters()) + list(self.clip_sk.parameters())
|
| 39 |
|
| 40 |
optimizer = torch.optim.Adam([
|
| 41 |
-
{'params': model_params, 'lr': self.opts.clip_LN_lr}
|
| 42 |
-
|
| 43 |
return optimizer
|
| 44 |
|
| 45 |
def forward(self, data, dtype='image'):
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
def configure_optimizers(self):
|
| 35 |
+
model_params = list(self.dino.parameters())
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
optimizer = torch.optim.Adam([
|
| 38 |
+
{'params': model_params, 'lr': self.opts.clip_LN_lr}]
|
| 39 |
+
)
|
| 40 |
return optimizer
|
| 41 |
|
| 42 |
def forward(self, data, dtype='image'):
|