# Assignment_3 / app.py
# Hugging Face file-view header captured with the source (not code):
#   author: Yoel125 · last commit: b721a24 "Update app.py" (verified) · size: 3.57 kB
import gradio as gr
import torch
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from deep_translator import GoogleTranslator
from datasets import load_dataset
# --- One-time startup: load the image dataset, precomputed embeddings, and models. ---
print("Downloading dataset and initializing model...")
# Source of the bird images and of the ClassLabel names; rows are looked up
# by the `dataset_index` values stored in the parquet file below.
dataset = load_dataset('JotDe/birds')

print("Loading embeddings...")
# One row per indexed image: `dataset_index` (row in dataset['train']),
# `label` (integer class id), and the remaining columns holding the vector.
df = pd.read_parquet("bird_embeddings.parquet")
subset_indices = df['dataset_index'].tolist()
subset_labels = df['label'].tolist()
feature_cols = [c for c in df.columns if c not in ['dataset_index', 'label']]
# Force float32 so the corpus matches the (typically float32) query vectors
# from SentenceTransformer.encode. A float64 parquet would otherwise make the
# cosine-similarity matmul fail on a dtype mismatch.
embeddings = df[feature_cols].to_numpy(dtype="float32")
dataset_embeddings = torch.tensor(embeddings, device='cpu')

# CLIP text/image model. The parquet vectors are presumably CLIP image
# embeddings from this same checkpoint (TODO confirm), so text queries land
# in the same space.
model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device='cpu')
# Auto-detects the query language and translates it to English before encoding.
translator = GoogleTranslator(source='auto', target='en')
def get_recommendations(text_input, top_k=3):
    """Return the dataset birds whose embeddings best match ``text_input``.

    The query is encoded with the module-level ``model`` and ranked by cosine
    similarity against ``dataset_embeddings``.

    Args:
        text_input: Free-text description (the caller translates it to English).
        top_k: Maximum number of matches to return; values <= 0 yield ``[]``.

    Returns:
        A list of at most ``top_k`` dicts, best match first, each with keys
        ``"image"`` (the ``image`` field of the matching dataset['train'] row)
        and ``"label"`` (the species name). Returns ``[]`` for empty or
        whitespace-only input.
    """
    if not text_input or not text_input.strip():
        return []
    # Without this guard a negative top_k would slice from the end ([:-1]).
    if top_k <= 0:
        return []

    query_embedding = model.encode(text_input, convert_to_tensor=True, device='cpu')
    # Match the corpus dtype so cos_sim's matmul cannot fail on float32 vs float64.
    query_embedding = query_embedding.to(dataset_embeddings.dtype)
    similarities = util.cos_sim(query_embedding, dataset_embeddings)[0]
    # topk avoids a full sort; clamp k because topk rejects k > corpus size.
    k = min(top_k, similarities.shape[0])
    top_indices = torch.topk(similarities, k=k).indices.tolist()

    train_split = dataset['train']
    label_feature = train_split.features['label']
    recommendations = []
    for i in top_indices:
        # int() guards against numpy integer types, which the datasets row
        # indexer and ClassLabel.int2str may not accept.
        original_idx = int(subset_indices[i])
        label_id = int(subset_labels[i])
        recommendations.append({
            "image": train_split[original_idx]['image'],
            "label": label_feature.int2str(label_id),
        })
    return recommendations
def gradio_interface(text_input):
if not text_input or text_input.strip() == "":
return None, "### Please enter a description.", None, "", None, ""
if any(char.isdigit() for char in text_input):
return None, "### Please put a bird type (example: blue bird)", None, "", None, ""
try:
english_query = translator.translate(text_input)
except:
english_query = text_input
results = get_recommendations(text_input=english_query, top_k=3)
if not results or len(results) < 3:
return None, "### No matches found.", None, "", None, ""
name1 = f"### {results[0]['label']}"
name2 = f"### {results[1]['label']}"
name3 = f"### {results[2]['label']}"
return results[0]['image'], name1, results[1]['image'], name2, results[2]['image'], name3
# UI: a query column on the left and three image/name match slots on the right.
with gr.Blocks(title="🐦 Smart Bird Tracker") as demo:
    gr.Markdown("# 🐦 Smart Bird Tracker")
    gr.Markdown(
        "Describe the bird you are looking for in **English**, **Spanish**, "
        "or **Hebrew**, and the AI will find the closest matches!"
    )
    with gr.Row():
        with gr.Column(scale=1):
            text_in = gr.Textbox(label="Describe the bird", placeholder="Type your description here...")
            submit_btn = gr.Button("Find Birds", variant="primary")
        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column():
                    out_img1 = gr.Image(label="Top Match 1")
                    out_name1 = gr.Markdown()
                with gr.Column():
                    out_img2 = gr.Image(label="Top Match 2")
                    out_name2 = gr.Markdown()
                with gr.Column():
                    out_img3 = gr.Image(label="Top Match 3")
                    out_name3 = gr.Markdown()

    # Order must match the 6-tuple returned by gradio_interface.
    search_outputs = [out_img1, out_name1, out_img2, out_name2, out_img3, out_name3]
    # Both the button and pressing Enter in the textbox run the same search.
    submit_btn.click(fn=gradio_interface, inputs=[text_in], outputs=search_outputs)
    text_in.submit(fn=gradio_interface, inputs=[text_in], outputs=search_outputs)

demo.launch()