Commit
·
dd2978a
1
Parent(s):
a93f0c7
add ragatouille search
Browse files- app.py +66 -54
- ragatouille_search.py +107 -0
app.py
CHANGED
|
@@ -7,6 +7,7 @@ import gradio as gr
|
|
| 7 |
import httpx
|
| 8 |
from cashews import cache
|
| 9 |
from huggingface_hub import ModelCard
|
|
|
|
| 10 |
|
| 11 |
cache.setup("mem://")
|
| 12 |
API_URL = "https://davanstrien-huggingface-datasets-search-v2.hf.space/similar"
|
|
@@ -150,61 +151,72 @@ async def search_similar_datasets(dataset_id: str, limit: int = 10):
|
|
| 150 |
|
| 151 |
|
| 152 |
with gr.Blocks() as demo:
|
| 153 |
-
gr.Markdown("## 🤗 Dataset Similarity
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
"
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
search_similar_datasets(dataset_id, limit)
|
| 203 |
-
if search_type == "Dataset ID"
|
| 204 |
-
else search_similar_datasets_by_text(text_query, limit)
|
| 205 |
-
),
|
| 206 |
-
inputs=[search_type, dataset_id, text_query, max_results],
|
| 207 |
-
outputs=results,
|
| 208 |
-
)
|
| 209 |
|
| 210 |
demo.launch()
|
|
|
|
| 7 |
import httpx
|
| 8 |
from cashews import cache
|
| 9 |
from huggingface_hub import ModelCard
|
| 10 |
+
from ragatouille_search import create_ragatouille_interface, search_with_ragatouille
|
| 11 |
|
| 12 |
cache.setup("mem://")
|
| 13 |
API_URL = "https://davanstrien-huggingface-datasets-search-v2.hf.space/similar"
|
|
|
|
| 151 |
|
| 152 |
|
| 153 |
# Top-level UI: one page with two tabs, the similarity search and the
# RAGatouille search built by create_ragatouille_interface().
with gr.Blocks() as demo:
    gr.Markdown("## 🤗 Dataset Search and Similarity")

    with gr.Tabs():
        # Tab 1: look up similar datasets by dataset ID or by a text query.
        with gr.TabItem("Similar Datasets"):
            gr.Markdown("## 🤗 Dataset Similarity Search")
            with gr.Row():
                gr.Markdown(
                    "This Gradio app allows you to find similar datasets based on a given dataset ID or a text query. "
                    "Choose the search type and enter either a dataset ID or a text query to find similar datasets with previews of their dataset cards.\n\n"
                    "For a seamless experience on the Hugging Face website, check out the "
                    "[Hugging Face Similar Chrome extension](https://chromewebstore.google.com/detail/hugging-face-similar/aijelnjllajooinkcpkpbhckbghghpnl?authuser=0&hl=en). "
                    "This extension adds a 'Similar Datasets' section directly to Hugging Face dataset pages, "
                    "making it even easier to discover related datasets for your projects."
                )

            with gr.Row():
                # Selects which of the two input boxes below is shown and used.
                search_type = gr.Radio(
                    ["Dataset ID", "Text Query"],
                    label="Search Type",
                    value="Dataset ID",
                )

            with gr.Row():
                dataset_id = gr.Textbox(
                    value="airtrain-ai/fineweb-edu-fortified",
                    label="Dataset ID (e.g., airtrain-ai/fineweb-edu-fortified)",
                )
                # Hidden initially because the default search type is "Dataset ID".
                text_query = gr.Textbox(
                    label="Text Query (e.g., 'natural language processing dataset')",
                    visible=False,
                )

            with gr.Row():
                search_btn = gr.Button("Search Similar Datasets")
                max_results = gr.Slider(
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=10,
                    label="Maximum number of results",
                )

            results = gr.Markdown()

            def toggle_input_visibility(choice):
                """Return visibility updates for (dataset_id, text_query) given the radio choice."""
                return gr.update(visible=choice == "Dataset ID"), gr.update(
                    visible=choice == "Text Query"
                )

            # Show only the input box matching the selected search type.
            search_type.change(
                toggle_input_visibility,
                inputs=[search_type],
                outputs=[dataset_id, text_query],
            )

            # The search helpers are coroutines, so the synchronous handler
            # drives whichever one is selected with asyncio.run().
            search_btn.click(
                lambda search_type, dataset_id, text_query, limit: asyncio.run(
                    search_similar_datasets(dataset_id, limit)
                    if search_type == "Dataset ID"
                    else search_similar_datasets_by_text(text_query, limit)
                ),
                inputs=[search_type, dataset_id, text_query, max_results],
                outputs=results,
            )

        # Tab 2: the interface is constructed inside this tab's context.
        with gr.TabItem("RAGatouille Search"):
            ragatouille_interface = create_ragatouille_interface()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
demo.launch()
|
ragatouille_search.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import gradio as gr
|
| 3 |
+
from huggingface_hub import snapshot_download
|
| 4 |
+
from ragatouille import RAGPretrainedModel
|
| 5 |
+
from toolz import unique
|
| 6 |
+
from typing import List, Dict, Any
|
| 7 |
+
|
| 8 |
+
# Top-level variables
# Local directory the index snapshot is downloaded into and then loaded from.
INDEX_PATH = Path(".ragatouille/colbert/indexes/my_index_with_ids_and_metadata/")
# Hub repository (fetched with repo_type="dataset") that holds the index files.
REPO_ID = "davanstrien/search-index"

# Query issued once right after loading, to warm up the index.
INITIAL_QUERY = "hello world"
# Default number of results, shared by search_with_ragatouille and the UI slider.
DEFAULT_K = 10
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def initialize_index():
    """Download the search index from the Hub and load it for querying.

    The snapshot of ``REPO_ID`` is written to ``INDEX_PATH`` (created if it
    does not exist), loaded with ``RAGPretrainedModel.from_index`` and then
    queried once with ``INITIAL_QUERY`` as a warm-up.

    Returns:
        The object returned by ``RAGPretrainedModel.from_index``.
    """
    INDEX_PATH.mkdir(parents=True, exist_ok=True)
    snapshot_download(REPO_ID, repo_type="dataset", local_dir=INDEX_PATH)

    model = RAGPretrainedModel.from_index(INDEX_PATH)
    model.search(INITIAL_QUERY)  # warm-up query; the result is discarded
    return model
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def format_results_as_markdown(results: List[Dict[str, Any]]) -> str:
    """Render search results as a single Markdown string.

    Every result must provide the keys ``content``, ``score``, ``rank``,
    ``document_id`` and ``passage_id``. Each one becomes a section with a
    heading, the score, a link to the dataset page on the Hub, the passage
    id and the content. Content longer than 1000 characters is previewed
    (first 1000 characters plus ``...``) and repeated in full inside a
    collapsible ``<details>`` element. Sections end with a horizontal rule.

    Args:
        results: Result dictionaries in display order.

    Returns:
        The concatenated Markdown, or an empty string for no results.

    Raises:
        KeyError: If a result lacks one of the required keys.
    """
    preview_limit = 1000
    fields = ("content", "score", "rank", "document_id", "passage_id")

    sections = []
    for hit in results:
        text, score, rank, doc_id, passage = (hit[name] for name in fields)
        url = f"https://huggingface.co/datasets/{doc_id}"
        is_long = len(text) > preview_limit

        pieces = [
            f"### Result {rank}\n",
            f"**Score:** {score}\n\n",
            f"**Document ID:** [{doc_id}]({url})\n\n",
            f"**Passage ID:** {passage}\n\n",
            f"{text[:preview_limit]}...\n\n" if is_long else f"{text}\n\n",
        ]
        if is_long:
            # Full text goes into an expandable section below the preview.
            pieces.extend(
                (
                    "<details>\n",
                    "<summary>Click to expand full content</summary>\n\n",
                    f"{text}\n\n",
                    "</details>\n\n",
                )
            )
        pieces.append("---\n\n")
        sections.append("".join(pieces))

    return "".join(sections)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def search_with_ragatouille(query, k=DEFAULT_K, make_unique=False):
    """Search the global index and return the hits formatted as Markdown.

    Args:
        query: Search text passed to ``RAG.search``.
        k: Number of results requested from the index.
        make_unique: When true, the hits are passed through
            ``make_results_unique`` before formatting, so fewer than ``k``
            results may be shown.

    Returns:
        The Markdown string produced by ``format_results_as_markdown``.
    """
    hits = RAG.search(query, k=k)
    shown = make_results_unique(hits) if make_unique else hits
    return format_results_as_markdown(shown)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def make_results_unique(results: List[Dict[str, Any]]):
    """Return the results de-duplicated by their ``document_id``.

    De-duplication is delegated to ``toolz.unique`` keyed on
    ``document_id``; the lazy result is materialised into a list.
    """

    def by_document(entry):
        return entry["document_id"]

    return list(unique(results, by_document))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def create_ragatouille_interface():
    """Build the Gradio Blocks UI for the RAGatouille dataset-card search.

    The interface holds a heading, an explanatory note, a query box, a
    result-count slider (1-100, default ``DEFAULT_K``), a "Unique Results"
    checkbox and a search button. Clicking the button calls
    ``search_with_ragatouille`` with the three inputs and writes its return
    value into a Markdown component.

    Returns:
        The constructed ``gr.Blocks`` object; it is not launched here.
    """
    with gr.Blocks() as ragatouille_demo:
        gr.Markdown("### RAGatouille Dataset Search")
        gr.Markdown(
            """This interface allows you to search inside dataset cards on the Hub using the [answerai-colbert-small-v1](https://huggingface.co/answerdotai/answerai-colbert-small-v1) ColBERT model via [RAGatouille](https://github.com/AnswerDotAI/RAGatouille). Please be aware that this is an early prototype and may not work as expected!

## Notes:
**Not all datasets are indexed yet!**
For a dataset to be indexed:
- It must have a dataset card on the Hub. You can find documentation on how to write a good dataset card [here](https://huggingface.co/docs/hub/datasets-cards).
- The dataset must have at least 1 like and 1 download
- The card must be a minimum length (to weed out low quality cards)
**At the moment the index is refreshed when I decide to do it, so it may not be up to date.** If there is sufficient interest I will implement a daily refresh (give this repo a like if you'd like this feature!)
Feel free to open a discussion to give feedback or request features 🤗
"""
        )

        query_box = gr.Textbox(label="Query")
        k_slider = gr.Slider(1, 100, value=DEFAULT_K, step=1, label="Number of Results")
        unique_checkbox = gr.Checkbox(False, label="Unique Results")
        search_button = gr.Button("Search")
        # Created after the button so the component creation order is unchanged.
        results_md = gr.Markdown(label="Results")
        search_button.click(
            search_with_ragatouille,
            inputs=[query_box, k_slider, unique_checkbox],
            outputs=results_md,
        )
    return ragatouille_demo
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# Initialize RAG globally
# NOTE: this executes at import time, so importing this module runs
# initialize_index() (snapshot download, index load and warm-up search)
# before any interface is built. search_with_ragatouille reads this global.
RAG = initialize_index()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def main():
    """Build the RAGatouille search interface and launch it standalone."""
    create_ragatouille_interface().launch()


# Only launch when executed as a script, not when imported by another module.
if __name__ == "__main__":
    main()
|