Spaces:
Sleeping
Sleeping
File size: 3,565 Bytes
e0a0485 d11647f e0a0485 13cabff d11647f e0a0485 d11647f e0a0485 d11647f e0a0485 d11647f e0a0485 d11647f e0a0485 d11647f e0a0485 d11647f e0a0485 d11647f e0a0485 d11647f e0a0485 d11647f 9e62b35 b721a24 e0a0485 d11647f e0a0485 d11647f e0a0485 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 | import gradio as gr
import torch
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from deep_translator import GoogleTranslator
from datasets import load_dataset
# ---------------------------------------------------------------------------
# One-time startup: load the bird image dataset, the precomputed embeddings,
# the CLIP encoder used for text queries, and the query translator.
# ---------------------------------------------------------------------------
print("Downloading dataset and initializing model...")
dataset = load_dataset('JotDe/birds')

print("Loading embeddings...")
# Each parquet row describes one dataset image:
#   'dataset_index' -> row position inside dataset['train']
#   'label'         -> integer class id (decoded later with int2str)
#   every other column is one component of that image's embedding vector.
df = pd.read_parquet("bird_embeddings.parquet")
subset_indices = df['dataset_index'].tolist()
subset_labels = df['label'].tolist()
feature_cols = [c for c in df.columns if c not in ('dataset_index', 'label')]
embeddings = df[feature_cols].values
# Force float32: the query embeddings produced by the sentence-transformers
# model are float32, and the cosine-similarity matmul rejects mixed
# float32/float64 operands. Parquet columns can come back as float64
# depending on how the file was written.
dataset_embeddings = torch.tensor(embeddings, dtype=torch.float32, device='cpu')

model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device='cpu')
# source='auto' lets non-English queries (e.g. Spanish, Hebrew) be translated
# to English before they are embedded.
translator = GoogleTranslator(source='auto', target='en')
def get_recommendations(text_input, top_k=3):
    """Return the dataset images whose embeddings best match a text query.

    The query is embedded with the module-level CLIP ``model`` and ranked
    against ``dataset_embeddings`` by cosine similarity.

    Args:
        text_input: Free-text description of a bird (already in English).
        top_k: Maximum number of matches to return; values <= 0 yield ``[]``.

    Returns:
        A list of at most ``top_k`` dicts, best match first, each holding
        ``"image"`` (the dataset image) and ``"label"`` (the species name).
        Empty or whitespace-only input returns ``[]``.
    """
    if not text_input or not text_input.strip():
        return []

    query_embedding = model.encode(text_input, convert_to_tensor=True, device='cpu')
    # Align with the corpus dtype so the similarity matmul never sees mixed
    # float32/float64 operands.
    query_embedding = query_embedding.to(dataset_embeddings.dtype)
    similarities = util.cos_sim(query_embedding, dataset_embeddings)[0]

    # torch.topk requires k <= number of candidates, so clamp it. A
    # non-positive top_k means "no results" rather than a negative slice.
    k = max(0, min(top_k, similarities.shape[0]))
    if k == 0:
        return []
    top_indices = torch.topk(similarities, k).indices.tolist()

    # Hoist the split and label-feature lookups out of the loop.
    train_split = dataset['train']
    label_feature = train_split.features['label']

    recommendations = []
    for i in top_indices:
        # Cast to plain int: the ids read from the parquet file may be numpy
        # scalars, which the datasets indexer / int2str may not accept.
        original_idx = int(subset_indices[i])
        label_id = int(subset_labels[i])
        recommendations.append({
            "image": train_split[original_idx]['image'],
            "label": label_feature.int2str(label_id),
        })
    return recommendations
def gradio_interface(text_input):
    """Gradio callback: validate and translate the query, then fill six output slots.

    Args:
        text_input: Raw textbox contents in any language the translator can
            auto-detect (the UI advertises English, Spanish and Hebrew).

    Returns:
        A 6-tuple ``(img1, name1, img2, name2, img3, name3)`` in the same order
        as the ``outputs`` wired to the search button. Names are Markdown
        headings. On invalid input, or when fewer than three matches come
        back, the first Markdown slot carries a message and the rest are cleared.
    """
    if not text_input or not text_input.strip():
        return None, "### Please enter a description.", None, "", None, ""
    # Queries containing any digit are rejected with a prompt for a bird type.
    if any(char.isdigit() for char in text_input):
        return None, "### Please put a bird type (example: blue bird)", None, "", None, ""

    # Translation is best-effort. On a translator failure (network error,
    # rate limit, ...) or an empty result, search with the user's raw text.
    # Catch Exception rather than using a bare except so Ctrl-C/SystemExit
    # still propagate.
    try:
        english_query = translator.translate(text_input)
    except Exception as err:
        print(f"Translation failed, searching with original text: {err}")
        english_query = None
    if not english_query:
        english_query = text_input

    results = get_recommendations(text_input=english_query, top_k=3)
    if not results or len(results) < 3:
        return None, "### No matches found.", None, "", None, ""

    # Interleave image / heading pairs to match the output component order.
    outputs = []
    for match in results[:3]:
        outputs.extend((match['image'], f"### {match['label']}"))
    return tuple(outputs)
# UI layout: a query column on the left and three ranked result cards on the
# right; each card shows the match image above its species heading.
with gr.Blocks(title="🐦 Smart Bird Tracker") as demo:
    gr.Markdown("# 🐦 Smart Bird Tracker")
    gr.Markdown("Describe the bird you are looking for in **English**, **Spanish**, or **Hebrew**, and the AI will find the closest matches!")

    with gr.Row():
        # Query input area.
        with gr.Column(scale=1):
            text_in = gr.Textbox(label="Describe the bird", placeholder="Type your description here...")
            submit_btn = gr.Button("Find Birds", variant="primary")

        # Result area: one column per rank, collected as image/heading pairs.
        result_slots = []
        with gr.Column(scale=2):
            with gr.Row():
                for rank in (1, 2, 3):
                    with gr.Column():
                        result_slots.append(gr.Image(label=f"Top Match {rank}"))
                        result_slots.append(gr.Markdown())

    out_img1, out_name1, out_img2, out_name2, out_img3, out_name3 = result_slots

    # The output order must match the 6-tuple returned by gradio_interface.
    submit_btn.click(
        fn=gradio_interface,
        inputs=[text_in],
        outputs=[out_img1, out_name1, out_img2, out_name2, out_img3, out_name3],
    )

demo.launch()