Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer, util | |
| from deep_translator import GoogleTranslator | |
| from datasets import load_dataset | |
# ---------------------------------------------------------------------------
# One-time startup: load the image dataset, the precomputed embedding table,
# the CLIP sentence-transformers model and the query translator.
# Everything is kept on CPU.
# ---------------------------------------------------------------------------
print("Downloading dataset and initializing model...")
# Image dataset; rows are fetched later by index from its 'train' split.
dataset = load_dataset('JotDe/birds')

print("Loading embeddings...")
# Precomputed embedding table. Besides the feature columns it carries two
# bookkeeping columns: 'dataset_index' (used to index dataset['train']) and
# 'label' (integer class id, converted to a species name at query time).
df = pd.read_parquet("bird_embeddings.parquet")
subset_indices = df['dataset_index'].tolist()
subset_labels = df['label'].tolist()

# Every column that is not bookkeeping is treated as one embedding dimension.
feature_cols = [c for c in df.columns if c not in ['dataset_index', 'label']]
embeddings = df[feature_cols].values

# Pin the dtype instead of inheriting whatever the parquet file stores: the
# query embedding from the model is expected to be float32 (TODO confirm), and
# a float64 table would make the similarity matmul fail on mismatched dtypes.
dataset_embeddings = torch.tensor(embeddings, dtype=torch.float32, device='cpu')

model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device='cpu')
# source='auto' lets the translator detect the input language; queries are
# always translated to English before being embedded.
translator = GoogleTranslator(source='auto', target='en')
def get_recommendations(text_input, top_k=3):
    """Return the indexed images whose embeddings best match a text query.

    Args:
        text_input: Query text. ``None``, an empty string or a
            whitespace-only string yields an empty result.
        top_k: Maximum number of matches to return. Values larger than the
            number of indexed embeddings are clamped to that number; values
            less than or equal to zero yield an empty list.

    Returns:
        A list of dicts ordered from best to worst match. Each dict has an
        ``"image"`` entry (the dataset row's image) and a ``"label"`` entry
        (the species name as a string).
    """
    if not text_input or text_input.strip() == "":
        return []

    query_embedding = model.encode(text_input, convert_to_tensor=True, device='cpu')
    # cos_sim returns a (1, N) matrix for a single query; take its only row.
    similarities = util.cos_sim(query_embedding, dataset_embeddings)[0]

    # Clamp k so that an oversized request cannot make topk raise and a
    # non-positive request returns nothing (instead of slicing from the end).
    k = max(0, min(int(top_k), similarities.shape[0]))
    if k == 0:
        return []

    # topk selects the k best scores, already sorted in descending order,
    # without sorting the whole similarity vector.
    top_indices = torch.topk(similarities, k).indices.tolist()

    # Loop-invariant lookups, hoisted out of the loop.
    train_split = dataset['train']
    label_feature = train_split.features['label']

    recommendations = []
    for i in top_indices:
        # int(...) guards against numpy integer types coming from the
        # DataFrame columns, which the dataset indexing may not accept.
        original_idx = int(subset_indices[i])
        label_id = int(subset_labels[i])
        recommendations.append({
            "image": train_split[original_idx]['image'],
            "label": label_feature.int2str(label_id),
        })
    return recommendations
def gradio_interface(text_input):
    """Gradio callback: turn a free-text bird description into three matches.

    The query is validated, translated to English (best effort) and passed to
    ``get_recommendations``.

    Args:
        text_input: The raw text typed by the user; may be ``None`` or empty.

    Returns:
        A 6-tuple ``(image1, name1, image2, name2, image3, name3)`` matching
        the six output components. On invalid input or when fewer than three
        matches are available, all images are ``None``, ``name1`` carries a
        Markdown message and the other names are empty strings.
    """
    def _message_only(message):
        # Clear every image slot and show `message` in the first name slot.
        return None, message, None, "", None, ""

    if not text_input or text_input.strip() == "":
        return _message_only("### Please enter a description.")

    # Any digit in the query is rejected with a hint to describe a bird type.
    if any(char.isdigit() for char in text_input):
        return _message_only("### Please put a bird type (example: blue bird)")

    # Translation is best effort: on failure, search with the original text.
    # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
    # still propagate, and report the reason instead of hiding it.
    try:
        english_query = translator.translate(text_input)
    except Exception as exc:
        print(f"Translation failed, searching with the original text: {exc}")
        english_query = text_input

    # Guard against the translator handing back nothing usable.
    if not isinstance(english_query, str) or not english_query.strip():
        english_query = text_input

    results = get_recommendations(text_input=english_query, top_k=3)
    if not results or len(results) < 3:
        return _message_only("### No matches found.")

    first, second, third = results[:3]
    return (
        first['image'], f"### {first['label']}",
        second['image'], f"### {second['label']}",
        third['image'], f"### {third['label']}",
    )
# ---------------------------------------------------------------------------
# Gradio UI: a query panel next to three result slots (image + species name).
# NOTE(review): the nesting below is reconstructed from a flattened source;
# confirm it matches the intended layout.
# ---------------------------------------------------------------------------
with gr.Blocks(title="🐦 Smart Bird Tracker") as demo:
    gr.Markdown("# 🐦 Smart Bird Tracker")
    gr.Markdown("Describe the bird you are looking for in **English**, **Spanish**, or **Hebrew**, and the AI will find the closest matches!")

    with gr.Row():
        # Narrow column: the query box and the search button.
        with gr.Column(scale=1):
            text_in = gr.Textbox(label="Describe the bird", placeholder="Type your description here...")
            submit_btn = gr.Button("Find Birds", variant="primary")

        # Wide column: three side-by-side result slots, created in the same
        # order the callback returns its values (image, name, image, name, ...).
        with gr.Column(scale=2):
            with gr.Row():
                result_slots = []
                for slot_label in ("Top Match 1", "Top Match 2", "Top Match 3"):
                    with gr.Column():
                        result_slots.append(gr.Image(label=slot_label))
                        result_slots.append(gr.Markdown())

    # Keep the individually named component handles available at module level.
    out_img1, out_name1, out_img2, out_name2, out_img3, out_name3 = result_slots

    submit_btn.click(
        fn=gradio_interface,
        inputs=[text_in],
        outputs=[out_img1, out_name1, out_img2, out_name2, out_img3, out_name3],
    )

demo.launch()