# mkfinder-ai / app.py
# (HuggingFace Spaces page header — scrape residue, kept as comments)
# Krish025's picture / Update app.py / 1b7956f verified
"""
MKfinder AI API - HuggingFace Spaces
EXACT same logic as predict.py (which worked on XAMPP)
Just wrapped in FastAPI for HuggingFace serving
"""
import os, re, sys, io, json, warnings, logging

# Silence warnings and HF/transformers logging. The env flags are set
# before any HuggingFace library is imported so they take effect.
warnings.filterwarnings("ignore")
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
os.environ["HF_HUB_VERBOSITY"] = "error"
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.ERROR)

from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
# Wide-open CORS — presumably the browser frontend (XAMPP/PHP per the module
# docstring) is served from a different origin; TODO confirm.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# SigLIP checkpoint fine-tuned on 526 bird species.
HF_MODEL = "prithivMLmods/Bird-Species-Classifier-526"
# Top-1 percentage below which the species is reported as "Unknown".
CONFIDENCE_THRESHOLD = 8.0
# ImageNet-1k class ids treated as birds — presumably ids 7-23
# (cock .. water ouzel); TODO confirm against the label list.
IMAGENET_BIRD_INDICES = set(range(7, 24))
# Words that mark an ImageNet label as a bird. Words longer than 4 chars
# are matched as plain substrings; 4-char-or-shorter words are matched on
# word boundaries (see is_bird_label) to avoid accidental hits.
BIRD_WORDS = [
    "cock","hen","ostrich","brambling","goldfinch","finch","junco",
    "bunting","robin","bulbul","jay","magpie","chickadee","ouzel",
    "kite","eagle","vulture","bird","hawk","owl","duck","goose",
    "crane","heron","woodpecker","wren","thrush","warbler","dove",
    "pigeon","parrot","penguin","flamingo","pelican","stork","swift",
    "swallow","kingfisher","quail","peacock","toucan","hornbill",
    "macaw","cockatoo","ibis","albatross","raven","falcon","osprey",
    "kestrel","grouse","sparrow","hummingbird","cardinal",
]
# ImageNet label fragments that clearly indicate a non-bird subject.
# Matched case-insensitively as substrings (see mobilenet_check STEP 2).
# Fix: the original list contained exact duplicates ("crutch", "nail",
# "hourglass", "stopwatch"); they are deduplicated here — the membership
# set, and therefore behavior, is unchanged.
NON_BIRD_KEYWORDS = [
    # people: clothing / accessories
    "sunglass","sunglasses","glasses","spectacles",
    "suit","tie","bow tie","bolo tie","windsor tie","jersey","lab coat",
    "dress","shirt","skirt","swimwear","bra","mitten","glove","stole",
    "hair","beard","lipstick","wig","bandage","mask","neck brace",
    "seat belt","crutch","snorkel","apron","cardigan","cloak","poncho",
    # vehicles
    "car","truck","bus","bicycle","motorcycle","airplane","train","boat",
    "cab","minivan","ambulance","tractor","forklift","scooter",
    # electronics
    "laptop","phone","keyboard","monitor","television","camera","remote",
    "cellular telephone","ipod","projector","printer","modem","mouse",
    # food / tableware
    "pizza","burger","sandwich","bottle","cup","bowl","plate","vase",
    "wine bottle","beer bottle","water bottle","pop bottle",
    "coffee mug","pitcher","pot","pan","ladle","spatula","tongs",
    # furniture
    "chair","sofa","table","desk","bed","shelf","bookcase","filing cabinet",
    "couch","rocking chair","stool","wardrobe","chest","safe",
    # tools / household objects
    "book","pen","pencil","scissors","ruler","hammer","nail","screwdriver",
    "wrench","plier","axe","shovel","hatchet","chisel",
    "plunger","drumstick","Band Aid","eraser","torch","lighter",
    "umbrella","stretcher","bucket","barrel","cistern",
    "radio","dial telephone","clock","stopwatch","hourglass",
    "pillow","blanket","bath towel","shower curtain","toilet",
    "candle","lamp","spotlight","chandelier","lantern",
    # non-bird animals
    "dog","cat","horse","cow","bear","lion","tiger","elephant","monkey",
    "snake","lizard","frog","spider","scorpion","fish","shark","whale",
    "rabbit","hamster","squirrel","fox","wolf","deer","zebra","giraffe",
    # structures / places
    "building","house","bridge","tower","wall","window","door",
    "fountain","statue","column","pedestal","streetcar","barn",
    # toys
    "teddy","toy","doll","stuffed","puppet","rocking horse",
    # textiles
    "sari","saree","fabric","textile","weaving","wool","silk","velvet",
    # paper / documents
    "paper","envelope","toilet paper","newspaper","comic book","menu",
    "receipt","document","letter","binder","notebook","notepad",
    "measuring cup","measuring stick","rule","abacus","cash machine",
    # body parts and jewellery
    "hand","arm","leg","foot","thumb","finger","knee","elbow",
    "necklace","bracelet","ring","earring","bangle","jewel","crown",
    # timepieces
    "wall clock","digital clock","sundial",
]
def is_bird_label(label):
    """Return True when *label* names a bird according to BIRD_WORDS.

    Words longer than four characters match as plain substrings;
    shorter words must match on a word boundary so that e.g. "hen"
    does not fire inside "kitchen".
    """
    text = label.lower()
    for word in BIRD_WORDS:
        if len(word) > 4:
            if word in text:
                return True
        elif re.search(r'\b' + re.escape(word) + r'\b', text):
            return True
    return False
def not_bird_message(label):
    """Map a non-bird classifier label to a user-facing rejection message.

    Categories are tried in order; the first whose keyword appears as a
    substring of the lowercased label wins. Falls back to a generic
    "not a bird" message.
    """
    text = label.lower()
    categories = (
        (("sunglass", "sunglasses", "glasses", "suit", "tie", "dress", "shirt",
          "jersey", "hair", "beard", "person", "human", "face",
          "coat", "skirt", "swimwear", "neck", "crutch", "snorkel",
          "seat belt", "stole", "bandage", "mask", "wig", "lipstick",
          "apron", "cardigan", "cloak", "poncho", "mitten", "glove"),
         "This looks like a photo of a person. Please upload a clear photo of a bird."),
        (("teddy", "toy", "doll", "stuffed", "puppet"),
         "This looks like a photo of a toy or object. Please upload a clear photo of a bird."),
        (("dog", "cat", "horse", "cow", "bear", "lion", "tiger",
          "elephant", "monkey", "snake", "lizard", "frog", "fish",
          "rabbit", "hamster", "squirrel", "fox", "wolf", "deer"),
         "This looks like a photo of an animal, but not a bird. Please upload a bird photo."),
        (("car", "truck", "bus", "bicycle", "motorcycle",
          "airplane", "train", "boat", "phone", "laptop", "scooter"),
         "This looks like a photo of a vehicle or device. Please upload a bird photo."),
        (("bottle", "cup", "bowl", "plate", "vase", "chair", "lamp",
          "sofa", "table", "desk", "building", "house", "statue",
          "fountain", "pillow", "candle", "umbrella", "bucket"),
         "This looks like a photo of an object or place. Please upload a bird photo."),
    )
    for keywords, message in categories:
        if any(kw in text for kw in keywords):
            return message
    return "This doesn't appear to be a bird photo. Please upload a clear photo of a bird."
# ── Cached models ─────────────────────────────────────────────────────────────
# Lazily-initialized singletons, populated on first use by get_mobilenet()
# and get_siglip() so cold start stays fast.
_mn_model = _mn_labels = _mn_transform = None
_processor = _siglip = None
def get_mobilenet():
    """Lazily build and cache the MobileNetV3-Small gatekeeper model.

    Returns:
        (model, labels, transform): the eval-mode torchvision model, the
        ImageNet category names, and the preprocessing pipeline.
    """
    global _mn_model, _mn_labels, _mn_transform
    if _mn_model is None:
        import contextlib
        import torchvision.models as tvm
        import torchvision.transforms as T
        # Suppress weight-download chatter. Fix: the original restored
        # sys.__stdout__ / sys.__stderr__ in the finally block, which
        # clobbers any redirection that was active before this call;
        # redirect_stdout/redirect_stderr restore the previous streams.
        buf = io.StringIO()
        with contextlib.redirect_stdout(buf), contextlib.redirect_stderr(buf):
            weights = tvm.MobileNet_V3_Small_Weights.DEFAULT
            _mn_model = tvm.mobilenet_v3_small(weights=weights)
            _mn_labels = weights.meta["categories"]
        _mn_model.eval()
        # Standard ImageNet preprocessing (mean/std from torchvision docs).
        _mn_transform = T.Compose([
            T.Resize(256), T.CenterCrop(224), T.ToTensor(),
            T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
        ])
    return _mn_model, _mn_labels, _mn_transform
def get_siglip():
    """Lazily load and cache the SigLIP bird-species classifier.

    Returns:
        (processor, model): the image processor and the eval-mode
        SiglipForImageClassification model for HF_MODEL.
    """
    global _processor, _siglip
    if _siglip is None:
        import contextlib
        from transformers import AutoImageProcessor, SiglipForImageClassification
        # Suppress download/progress output. Fix: restore the previously
        # active streams via redirect_* instead of blindly resetting to
        # sys.__stdout__ / sys.__stderr__, which would clobber any outer
        # redirection (e.g. by the serving stack).
        buf = io.StringIO()
        with contextlib.redirect_stdout(buf), contextlib.redirect_stderr(buf):
            _processor = AutoImageProcessor.from_pretrained(HF_MODEL, use_fast=False)
            _siglip = SiglipForImageClassification.from_pretrained(HF_MODEL)
        _siglip.eval()
    return _processor, _siglip
def mobilenet_check(img):
    """
    ImageNet MobileNet gatekeeper — EXACT same logic as predict.py
    mobilenet_check(), which worked on XAMPP.

    Args:
        img: PIL RGB image (assumed — TODO confirm callers always convert).
    Returns:
        (is_not_bird, label, conf): True means "reject as non-bird";
        label is the deciding ImageNet category, conf its percentage.
    """
    import torch
    import torch.nn.functional as F
    mn, labels, transform = get_mobilenet()
    tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        probs = F.softmax(mn(tensor), dim=1).squeeze()
    top_probs, top_idxs = probs.topk(20)
    # STEP 1 — if ANY top-20 result is a bird index or bird word → PASS immediately
    for i in range(20):
        conf = top_probs[i].item() * 100
        if conf < 0.5:
            break  # topk is sorted descending; the rest are below 0.5% too
        idx = top_idxs[i].item()
        label = labels[idx]
        if idx in IMAGENET_BIRD_INDICES and conf >= 1.0:
            return False, label, conf
        if is_bird_label(label) and conf >= 1.0:
            return False, label, conf
    # STEP 2 — no bird found — check for non-bird keywords at >= 3% confidence
    # (comment previously said 10%, but the code threshold is 3.0)
    for i in range(20):
        conf = top_probs[i].item() * 100
        if conf < 0.5:
            break
        label = labels[top_idxs[i].item()].lower()
        full = labels[top_idxs[i].item()]  # original casing for the caller's message
        for kw in NON_BIRD_KEYWORDS:
            if kw.lower() in label and conf >= 3.0:
                return True, full, conf
    # STEP 3 — CATCH-ALL: no bird evidence at all + top conf >= 8% → block
    # (comment previously said 20%, but the code threshold is 8.0)
    top_conf = top_probs[0].item() * 100
    top_label = labels[top_idxs[0].item()]
    if top_conf >= 8.0:
        return True, top_label, top_conf
    # STEP 4 — very uncertain — pass to Siglip (could be exotic bird)
    return False, top_label, top_conf
# ── FastAPI endpoints ─────────────────────────────────────────────────────────
@app.get("/")
def root():
return {"status": "MKfinder AI", "version": "5.0"}
@app.get("/health")
def health():
return {"status": "ok"}
@app.post("/identify")
async def identify(image: UploadFile = File(...)):
try:
import torch
import torch.nn.functional as F
from PIL import Image as PILImage
data = await image.read()
img = PILImage.open(io.BytesIO(data)).convert("RGB")
# ── PIXEL CHECK: catch receipts, white paper, solid color images ──────
import numpy as np
arr = np.array(img.resize((150, 150)), dtype=float) / 255.0
r, g, b = arr[:,:,0], arr[:,:,1], arr[:,:,2]
brightness = (r + g + b) / 3.0
cmax = np.maximum(np.maximum(r,g),b)
cmin = np.minimum(np.minimum(r,g),b)
saturation = (cmax - cmin) / (cmax + 1e-9)
white_pct = float(np.sum(brightness > 0.85)) / brightness.size * 100
low_sat_pct = float(np.sum(saturation < 0.12)) / saturation.size * 100
dark_pct = float(np.sum(brightness < 0.12)) / brightness.size * 100
# Receipt/bill/paper: very white + very low saturation
if white_pct > 55.0 and low_sat_pct > 60.0:
return {"success": False,
"error": "This looks like a document or paper. Please upload a clear photo of a bird.",
"error_code": "NOT_A_BIRD"}
# Skin tone check β€” catches selfies and group photos
h = np.zeros_like(r)
diff = cmax - cmin + 1e-9
mr=(cmax==r); mg=(cmax==g)&~mr; mb=(cmax==b)&~mr&~mg
h[mr]=(60.0*((g[mr]-b[mr])/diff[mr]))%360
h[mg]=(60.0*((b[mg]-r[mg])/diff[mg]))+120
h[mb]=(60.0*((r[mb]-g[mb])/diff[mb]))+240
skin = (((h<=28)|(h>=335)) & (saturation>=0.10) & (saturation<=0.68) & (brightness>=0.35))
skin_pct = float(np.sum(skin)) / skin.size * 100
rows, cols = r.shape
face_region = skin[:int(rows*0.6), int(cols*0.15):int(cols*0.85)]
face_pct = float(np.sum(face_region)) / face_region.size * 100
if skin_pct > 18.0 and face_pct > 12.0:
return {"success": False,
"error": "This looks like a photo of a person. Please upload a clear photo of a bird.",
"error_code": "NOT_A_BIRD"}
# ── STAGE 1: MobileNet β€” exact same logic as predict.py ───────────────
is_not_bird, mn_label, mn_conf = mobilenet_check(img)
if is_not_bird:
return {
"success": False,
"error": not_bird_message(mn_label),
"error_code": "NOT_A_BIRD"
}
# ── STAGE 2: Siglip bird identification ───────────────────────────────
processor, siglip = get_siglip()
inputs = processor(images=img, return_tensors="pt")
with torch.no_grad():
probs = F.softmax(siglip(**inputs).logits, dim=1).squeeze().tolist()
id2label = siglip.config.id2label
score_map = {
id2label.get(i, f"Class_{i}").upper().strip(): float(p) * 100
for i, p in enumerate(probs)
}
top_label = max(score_map, key=score_map.get)
top_conf = score_map[top_label]
# Siglip guard β€” non-bird images score below 3.5% across all 526 species
if top_conf < 3.5:
return {
"success": False,
"error": "This doesn't appear to be a bird photo. Please upload a clear photo of a bird.",
"error_code": "NOT_A_BIRD"
}
# Title case for DB matching: "SNOWY OWL" β†’ "Snowy Owl"
top_label_nice = top_label.title()
if top_conf >= CONFIDENCE_THRESHOLD:
return {
"success": True,
"species": top_label_nice,
"confidence": round(top_conf, 2),
"mode": "ai_model"
}
else:
return {
"success": True,
"species": "Unknown",
"confidence": round(top_conf, 2),
"mode": "ai_model_low_confidence"
}
except Exception as e:
return {"success": False, "error": str(e)}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)