File size: 3,565 Bytes
e0a0485
 
 
 
 
 
 
d11647f
e0a0485
 
 
 
 
 
 
 
 
 
 
 
 
 
13cabff
d11647f
 
 
e0a0485
 
 
 
 
 
d11647f
 
 
e0a0485
 
d11647f
e0a0485
 
 
 
d11647f
e0a0485
 
 
 
 
d11647f
 
e0a0485
d11647f
e0a0485
 
 
 
 
 
d11647f
e0a0485
d11647f
 
e0a0485
d11647f
 
 
 
 
e0a0485
 
 
d11647f
 
9e62b35
b721a24
 
e0a0485
 
 
 
 
 
d11647f
 
 
 
 
 
 
 
 
 
e0a0485
d11647f
 
 
 
 
e0a0485
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import torch
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from deep_translator import GoogleTranslator
from datasets import load_dataset

# --- One-time startup: dataset, precomputed embeddings, and models ----------
# Everything below runs at import time and is shared by the request handlers.
print("Downloading dataset and initializing model...")
dataset = load_dataset('JotDe/birds')

print("Loading embeddings...")
df = pd.read_parquet("bird_embeddings.parquet")
# Row i of the parquet file maps to dataset['train'][subset_indices[i]] and
# carries the integer class id subset_labels[i].
subset_indices = df['dataset_index'].tolist()
subset_labels = df['label'].tolist()

# Every column other than the two bookkeeping columns is treated as one
# embedding dimension.
feature_cols = [c for c in df.columns if c not in ['dataset_index', 'label']]
embeddings = df[feature_cols].values
# Pin the dtype to float32. Without it the tensor inherits whatever dtype the
# parquet columns have (e.g. float64), and the matrix product performed by
# util.cos_sim would raise a dtype-mismatch error against the query embedding
# (presumably float32, the PyTorch default -- confirm for the model in use).
dataset_embeddings = torch.tensor(embeddings, dtype=torch.float32, device='cpu')

# CLIP text/image encoder used to embed the user's query; kept on CPU.
model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device='cpu')
# Queries are translated to English (source language auto-detected) before
# being embedded.
translator = GoogleTranslator(source='auto', target='en')

def get_recommendations(text_input, top_k=3):
    """Return the ``top_k`` dataset entries most similar to ``text_input``.

    The query is embedded with the module-level ``model`` and compared by
    cosine similarity against ``dataset_embeddings``.

    Args:
        text_input: Free-text description to search for.
        top_k: Maximum number of matches to return.

    Returns:
        A list of at most ``top_k`` dicts, best match first, each with the
        keys ``"image"`` (the ``image`` field of the matching training row)
        and ``"label"`` (the label id converted to its string name). The
        list is empty when ``text_input`` is empty/blank or ``top_k`` is not
        positive.
    """
    # Guard before touching the model: a blank query has nothing to embed, and
    # a negative top_k used as a slice bound would return almost every row.
    if not text_input or not text_input.strip() or top_k <= 0:
        return []

    query_embedding = model.encode(text_input, convert_to_tensor=True, device='cpu')
    similarities = util.cos_sim(query_embedding, dataset_embeddings)[0]

    # topk returns the k best scores in descending order without sorting the
    # whole similarity vector; k is clamped because topk rejects k > size.
    k = min(top_k, similarities.shape[0])
    top_indices = torch.topk(similarities, k=k).indices.tolist()

    # Loop-invariant lookups, resolved once.
    train_split = dataset['train']
    label_feature = train_split.features['label']

    recommendations = []
    for i in top_indices:
        # int(...) guards against numpy integer types coming out of the
        # parquet-derived lists, which the dataset indexing may not accept.
        original_idx = int(subset_indices[i])
        label_id = int(subset_labels[i])
        recommendations.append({
            "image": train_split[original_idx]['image'],
            "label": label_feature.int2str(label_id),
        })
    return recommendations

def _message_only(message):
    """Build the six-output tuple that shows only ``message`` and no images."""
    return None, message, None, "", None, ""


def gradio_interface(text_input):
    """Handle a search request from the UI.

    Validates the text, translates it to English on a best-effort basis, and
    looks up the three closest matches.

    Args:
        text_input: Raw text typed by the user.

    Returns:
        A 6-tuple ``(image1, caption1, image2, caption2, image3, caption3)``
        matching the output components of the click handler. On invalid
        input or when fewer than three matches are available, the images are
        ``None`` and the first caption carries the message.
    """
    if not text_input or not text_input.strip():
        return _message_only("### Please enter a description.")

    # Any digit in the text is rejected outright.
    if any(char.isdigit() for char in text_input):
        return _message_only("### Please put a bird type (example: blue bird)")

    # Translation is best-effort: on any failure, or an empty result, search
    # with the text as typed. `except Exception` (rather than a bare except)
    # lets KeyboardInterrupt / SystemExit propagate.
    try:
        english_query = translator.translate(text_input) or text_input
    except Exception as exc:
        print(f"Translation failed, using the original text: {exc}")
        english_query = text_input

    results = get_recommendations(text_input=english_query, top_k=3)
    if len(results) < 3:
        return _message_only("### No matches found.")

    # Interleave image and caption for each of the three output slots.
    outputs = []
    for match in results[:3]:
        outputs.append(match['image'])
        outputs.append(f"### {match['label']}")
    return tuple(outputs)

def _result_slot(rank):
    """Create one result column: an image with a caption underneath.

    Must be called inside an open Gradio layout context. Returns the
    ``(image, caption)`` components in creation order.
    """
    with gr.Column():
        picture = gr.Image(label=f"Top Match {rank}")
        caption = gr.Markdown()
    return picture, caption


with gr.Blocks(title="🐦 Smart Bird Tracker") as demo:
    # Page header and usage hint.
    gr.Markdown("# 🐦 Smart Bird Tracker")
    gr.Markdown("Describe the bird you are looking for in **English**, **Spanish**, or **Hebrew**, and the AI will find the closest matches!")

    with gr.Row():
        # Left: query box and search button.
        with gr.Column(scale=1):
            text_in = gr.Textbox(
                label="Describe the bird",
                placeholder="Type your description here...",
            )
            submit_btn = gr.Button("Find Birds", variant="primary")

        # Right: three result slots side by side.
        with gr.Column(scale=2), gr.Row():
            out_img1, out_name1 = _result_slot(1)
            out_img2, out_name2 = _result_slot(2)
            out_img3, out_name3 = _result_slot(3)

    # Output order must match the 6-tuple returned by gradio_interface.
    submit_btn.click(
        fn=gradio_interface,
        inputs=[text_in],
        outputs=[
            out_img1, out_name1,
            out_img2, out_name2,
            out_img3, out_name3,
        ],
    )

# Start the web server.
demo.launch()