Spaces:
Running
on
Zero
Running
on
Zero
I can't use Git for the life of me. This might need more testing.
#2
by
npbm
- opened
- .gitattributes +35 -35
- README.md +12 -12
- app.py +56 -47
- requirements.txt +8 -7
- utils/dataset_rag.py +64 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,35 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: RAG
|
| 3 |
-
emoji: 🔥
|
| 4 |
-
colorFrom: yellow
|
| 5 |
-
colorTo: red
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 4.37.1
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
---
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Image RAG
|
| 3 |
+
emoji: 🔥
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: red
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.37.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
|
@@ -1,47 +1,56 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
-
import
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
return
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradio front-end for the image-RAG demo.

Lets the user download/index one of the pre-embedded datasets and then
search it by image via utils.dataset_rag.
"""
import gradio as gr
from utils import dataset_rag

# HACK: faiss and torch each bundle an OpenMP runtime; loading both aborts
# the process unless this escape hatch is set. Keep the flag so the
# workaround is easy to find and remove later.
dirty_hack = True

if dirty_hack:
    import os
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Datasets that already ship a CLIP "embeddings" column (see utils/dataset_rag.py).
datasets = [
    "not-lain/embedded-pokemon",
]

# Detect whether we run on a HF Space with ZeroGPU: the `spaces` package is
# only installed there, and indexing must then happen inside @spaces.GPU.
try:
    import spaces
    space_installed = True
except ImportError:
    space_installed = False

if space_installed:
    @spaces.GPU
    def instance(dataset_name):
        """Build a dataset_rag.Instance on a ZeroGPU worker."""
        return dataset_rag.Instance(dataset_name)
else:
    def instance(dataset_name):
        """Build a dataset_rag.Instance locally (CPU or plain GPU)."""
        return dataset_rag.Instance(dataset_name)


def download(dataset):
    """Download and index the selected dataset, caching it in the global `ds`.

    Args:
        dataset: hub id chosen in the dropdown.
    """
    global ds
    # Bug fix: index the dataset the user actually selected instead of
    # always taking datasets[0].
    client = instance(dataset)
    ds = client
    return client


def search_ds(image):
    """Search the currently indexed dataset for `image`.

    Returns:
        retrieved_examples, scores — as produced by Instance.search.
    """
    if ds is None:
        # Robustness: the original raised a bare AttributeError here when
        # Search was clicked before Download.
        raise gr.Error("Download a dataset first.")
    scores, retrieved_examples = ds.search(image)
    # NOTE(review): retrieved_examples is a column dict; the Gallery most
    # likely wants retrieved_examples["image"] — confirm against the
    # dataset schema before changing.
    return retrieved_examples, scores


with gr.Blocks(title="Image RAG") as demo:
    ds = None  # populated by download(); read by search_ds()
    dataset_name = gr.Dropdown(label="Dataset", choices=datasets, value=datasets[0])
    download_dataset = gr.Button("Download Dataset")

    search = gr.Image(label="Search Image")
    search_button = gr.Button("Search")
    results = gr.Gallery(label="Results")
    scores = gr.Textbox(label="Scores", type="text", value="")
    search_button.click(search_ds, inputs=[search], outputs=[results, scores])

    download_dataset.click(download, dataset_name)


demo.launch()
|
requirements.txt
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
| 1 |
+
datasets
|
| 2 |
+
accelerate
|
| 3 |
+
loadimg
|
| 4 |
+
faiss-cpu
|
| 5 |
+
numpy==1.26.0
|
| 6 |
+
transformers # hf spaces already have it installed.
|
| 7 |
+
pillow
|
| 8 |
+
gradio # duh
|
utils/dataset_rag.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_dataset
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from loadimg import load_img


def _pick_device() -> str:
    """Return the best available torch device name: 'cuda', 'mps' or 'cpu'.

    The original code only checked CUDA and left a TODO-style comment about
    MPS; Apple-silicon support is added here as a backward-compatible
    fallback (CUDA still wins when present).
    """
    if torch.cuda.is_available():
        return 'cuda'
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return 'mps'
    return 'cpu'


device = _pick_device()

# CLIP encoder used for both indexing and querying; loaded once at import
# time so every Instance shares the same weights.
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
model = AutoModelForZeroShotImageClassification.from_pretrained(
    "openai/clip-vit-large-patch14", device_map=device
)
+
class Instance:
|
| 12 |
+
def __init__(self, dataset, token=None, split="train"):
|
| 13 |
+
self.dataset = dataset
|
| 14 |
+
self.token = token
|
| 15 |
+
self.split = split
|
| 16 |
+
self.data = load_dataset(self.dataset, split=self.split)
|
| 17 |
+
self.data = self.data.add_faiss_index("embeddings")
|
| 18 |
+
|
| 19 |
+
def embed(batch):
|
| 20 |
+
"""a function that embeds a batch of images and returns the embeddings intended for embedding already existing images in an external dataset. (unused)"""
|
| 21 |
+
pixel_values = processor(images = batch["image"], return_tensors="pt")['pixel_values']
|
| 22 |
+
pixel_values = pixel_values.to(device)
|
| 23 |
+
img_emb = model.get_image_features(pixel_values)
|
| 24 |
+
batch["embeddings"] = img_emb
|
| 25 |
+
return batch
|
| 26 |
+
|
| 27 |
+
def search(self, query: str, k: int = 3 ):
|
| 28 |
+
"""
|
| 29 |
+
A function that embeds a query image and returns the most probable results.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
query: the image to search for
|
| 33 |
+
k: the number of results to return
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
scores: the scores of the retrieved examples (cosine similarity i think in this case)
|
| 37 |
+
retrieved_examples: the retrieved examples
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
pixel_values = processor(images = query, return_tensors="pt")['pixel_values']
|
| 41 |
+
pixel_values = pixel_values.to(device)
|
| 42 |
+
img_emb = model.get_image_features(pixel_values)[0]
|
| 43 |
+
img_emb = img_emb.cpu().detach().numpy()
|
| 44 |
+
|
| 45 |
+
scores, retrieved_examples = self.data.get_nearest_examples(
|
| 46 |
+
"embeddings", img_emb,
|
| 47 |
+
k=k
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
return scores, retrieved_examples
|
| 51 |
+
|
| 52 |
+
def high_level_search(self, img):
|
| 53 |
+
"""
|
| 54 |
+
High level wrapper for the search function.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
img: input image (path, url, pillow or numpy)
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
scores: the scores of the retrieved examples (cosine similarity i think in this case)
|
| 61 |
+
retrieved_examples: the retrieved examples
|
| 62 |
+
"""
|
| 63 |
+
image = load_img(img)
|
| 64 |
+
scores, retrieved_examples = self.search(image)
|