# Provenance: app/models/animal_vision.py — "Update app/models/animal_vision.py"
# by drrobot9, commit f703339 (verified). (Converted HF page residue to a comment.)
# app/models/animal_vision.py
import faiss
import torch
import open_clip
import numpy as np
from PIL import Image
from app.models.llm import explain_species
from app.utils.config import (
DEVICE,
BIOCLIP_MODEL_ID,
BIOCLIP_INDEX_PATH,
ANIMAL_SPECIES_LIST,
TOP_K_ANIMALS
)
# Load the BioCLIP model (and its matching image preprocessing transform)
# from the Hugging Face Hub at import time; the middle return value is the
# training-time transform, which is unused here.
model, _, preprocess = open_clip.create_model_and_transforms(
f"hf-hub:{BIOCLIP_MODEL_ID}"
)
model = model.to(DEVICE)
model.eval()  # inference mode: disable dropout / batch-norm updates
# Prebuilt FAISS index of species embeddings; searched in predict_animal().
index = faiss.read_index(str(BIOCLIP_INDEX_PATH))
# One species label per line. NOTE(review): line order must match the FAISS
# index's vector order — do not sort or filter this list. If the file ends
# with a trailing newline, the last entry may be an empty string — TODO confirm
# the species file has no blank lines.
with open(ANIMAL_SPECIES_LIST, "r", encoding="utf-8") as f:
SPECIES = [line.strip() for line in f]
@torch.no_grad()
def predict_animal(image: Image.Image) -> dict:
    """
    Classify an animal image via BioCLIP embedding + FAISS nearest-neighbor search.

    Args:
        image: input image in any PIL mode; converted to RGB before preprocessing.

    Returns:
        {
            "species": str,            # best-matching species label
            "common_name": str | None, # from the LLM explanation
            "confidence": float,       # similarity score of the best match
            "top_k": list,             # all retrieved matches with similarities
            "description": str | None  # LLM-generated explanation, if provided
        }

    Raises:
        ValueError: if the FAISS index returns no valid neighbors.
    """
    # Embed the image; L2-normalize so the index's inner-product scores
    # behave as cosine similarities.
    image_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    image_features = model.encode_image(image_tensor)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    image_np = image_features.cpu().numpy().astype("float32")

    scores, indices = index.search(image_np, TOP_K_ANIMALS)

    # FAISS pads with -1 when fewer than k vectors exist in the index;
    # without this guard, SPECIES[-1] would silently mislabel the result.
    results = [
        {"species": SPECIES[idx], "similarity": float(score)}
        for idx, score in zip(indices[0], scores[0])
        if idx >= 0
    ]
    if not results:
        raise ValueError("FAISS search returned no matches; is the index empty?")

    best = results[0]
    llm_result = explain_species(
        species=best["species"],
        confidence=best["similarity"],
        domain="animal",
        top_k=results,
    )

    return {
        "species": best["species"],
        "common_name": llm_result["common_name"],
        "confidence": best["similarity"],
        # Restored: the docstring promises these keys, but they had been
        # commented out of the return value. Adding keys is backward-compatible.
        "top_k": results,
        # .get() because explain_species' payload is defined elsewhere —
        # presumably it includes "description"; verify against app.models.llm.
        "description": llm_result.get("description"),
    }