Spaces:
Running
Running
import functools
import io
import logging
import os

import clip
import cv2
import numpy as np
import tensorflow as tf
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
app = FastAPI()

# --- Model loading (runs once, at import time) ---

# TFLite plant classifier. The path is relative, so it is resolved against
# the process's current working directory.
tflite_model = tf.lite.Interpreter(model_path="resnet_model.tflite")
tflite_model.allocate_tensors()

# CLIP model used to screen out non-plant images before classification.
clip_device = "cuda" if torch.cuda.is_available() else "cpu"
# Weights are downloaded into a writable directory under /tmp (presumably the
# app directory is read-only in the hosting environment -- confirm).
os.makedirs("/tmp/models", exist_ok=True)
clip_model, clip_preprocess = clip.load(
    "ViT-B/32", device=clip_device, download_root="/tmp/models"
)
# Output labels of the TFLite classifier. predict() indexes this list with
# the argmax position of the model output, so the ORDER must match the
# model's output layer. Spellings (e.g. "Pappaya", "Gauva", "Avacado") are
# returned to clients verbatim; presumably they mirror the training labels,
# so do not "correct" them without retraining/remapping.
class_names = [
    "Nooni",
    "Nithyapushpa",
    "Basale",
    "Pomegranate",
    "Honge",
    "Lemon_grass",
    "Mint",
    "Betel_Nut",
    "Nagadali",
    "Curry_Leaf",
    "Jasmine",
    "Castor",
    "Sapota",
    "Neem",
    "Ashoka",
    "Brahmi",
    "Amruta_Balli",
    "Pappaya",
    "Pepper",
    "Wood_sorel",
    "Gauva",
    "Hibiscus",
    "Ashwagandha",
    "Aloevera",
    "Raktachandini",
    "Insulin",
    "Bamboo",
    "Amla",
    "Arali",
    "Geranium",
    "Avacado",
    "Lemon",
    "Ekka",
    "Betel",
    "Henna",
    "Doddapatre",
    "Rose",
    "Mango",
    "Tulasi",
    "Ganike",
]

# Shared URL prefix of the KAMPA medicinal-plant brochures.
_KAMPA_BROCHURE_BASE = (
    "https://kampa.karnataka.gov.in/storage/pdf-files/"
    "brochure%20of%20medicinal%20plants/"
)

# Class name -> brochure PDF. Classes missing here get a fallback link in
# predict().
plant_pdf_map = {
    "Tulasi": _KAMPA_BROCHURE_BASE + "tulsi.pdf",
    "Neem": _KAMPA_BROCHURE_BASE + "neem.pdf",
    "Mint": _KAMPA_BROCHURE_BASE + "mint.pdf",
    "Aloevera": _KAMPA_BROCHURE_BASE + "aloevera.pdf",
}
# Helpers
def check_image_quality(
    image_pil,
    *,
    min_brightness=30,
    max_brightness=220,
    min_contrast=15,
    min_size=100,
):
    """Run cheap heuristics to reject images unlikely to classify well.

    Brightness is the mean and contrast the standard deviation of all pixel
    values (every pixel, every channel) of ``np.asarray(image_pil)``.

    Args:
        image_pil: A PIL image (anything exposing ``width``/``height`` and
            convertible with ``np.asarray`` works).
        min_brightness: Mean pixel value below this -> "Too dark".
        max_brightness: Mean pixel value above this -> "Too bright".
        min_contrast: Pixel std-dev below this -> "Low contrast".
        min_size: Width or height below this -> "Too small".

    Returns:
        ``(is_good, issues)``: ``issues`` is the list of failed-check labels
        in a fixed order; ``is_good`` is True exactly when it is empty.
    """
    pixels = np.asarray(image_pil)
    brightness = pixels.mean()
    contrast = pixels.std()
    # Single table of (failed?, label) so the verdict and the reported issues
    # can never disagree.
    checks = (
        (brightness < min_brightness, "Too dark"),
        (brightness > max_brightness, "Too bright"),
        (contrast < min_contrast, "Low contrast"),
        (image_pil.width < min_size or image_pil.height < min_size, "Too small"),
    )
    issues = [label for failed, label in checks if failed]
    return not issues, issues
# Zero-shot CLIP prompts used to decide whether an upload shows a plant.
_PLANT_PROMPTS = (
    "a photo of a plant", "a photo of a leaf", "a photo of a green plant",
    "a photo of a medicinal plant", "a photo of herbs",
)
_NON_PLANT_PROMPTS = (
    "a photo of a person", "a photo of food", "a document", "a vehicle",
)


@functools.lru_cache(maxsize=1)
def _prompt_text_features():
    """Return L2-normalised CLIP text embeddings for all screening prompts.

    The prompts are constant, so the text encoder runs once per process
    (on first use) instead of on every request.
    """
    tokens = clip.tokenize(list(_PLANT_PROMPTS + _NON_PLANT_PROMPTS)).to(clip_device)
    with torch.no_grad():
        features = clip_model.encode_text(tokens)
        return features / features.norm(dim=1, keepdim=True)


def validate_plant_image(image_pil):
    """Return True if CLIP scores the image as closer to plant prompts.

    The image is scored against every prompt; after a softmax over all
    prompts, the mean probability of the plant prompts is compared with the
    mean probability of the non-plant prompts.

    Args:
        image_pil: PIL image accepted by ``clip_preprocess``.

    Returns:
        ``True`` when the plant mean is strictly greater, else ``False``.
    """
    clip_image = clip_preprocess(image_pil).unsqueeze(0).to(clip_device)
    text_features = _prompt_text_features()
    with torch.no_grad():
        image_features = clip_model.encode_image(clip_image)
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        # Same computation as CLIP's forward() (logits_per_image); the learned
        # temperature is stored in log space, hence .exp().
        logits = clip_model.logit_scale.exp() * image_features @ text_features.t()
    probs = logits.softmax(dim=-1).cpu().numpy()[0]
    n_plant = len(_PLANT_PROMPTS)
    plant_conf = np.mean(probs[:n_plant])
    non_plant_conf = np.mean(probs[n_plant:])
    return bool(plant_conf > non_plant_conf)
logger = logging.getLogger(__name__)


def _prepare_tflite_input(image_pil, size=(224, 224)):
    """Convert an RGB PIL image into a ``(1, H, W, 3)`` float32 batch for TFLite.

    Args:
        image_pil: RGB ``PIL.Image``.
        size: Target ``(width, height)``, in the order ``cv2.resize`` expects.

    Returns:
        float32 array scaled from 0-255 to [0, 1], with a leading batch axis.
    """
    # Channels are flipped to BGR; presumably the model was trained on
    # cv2-loaded images -- TODO confirm against the training pipeline.
    image = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
    image = cv2.resize(image, size)
    image = (image / 255.0).astype(np.float32)
    return np.expand_dims(image, axis=0)


# NOTE(review): the "/predict" path is assumed -- confirm it matches what
# clients call.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded plant photo with the TFLite model.

    Pipeline: decode -> quality heuristics -> CLIP plant screen -> TFLite.

    Returns:
        On success, a dict with ``prediction`` (class name), ``confidence``
        (model output score of the top class), ``top_3_predictions``
        (class name -> score rounded to 4 places) and ``pdf_url`` (brochure
        link, or the KAMPA home page when no brochure is mapped).
        On failure, a ``JSONResponse``: 400 for an undecodable image, poor
        image quality or a non-plant image; 500 for any other error.
    """
    # NOTE: the CLIP and TFLite calls below are synchronous and run on the
    # event loop. This blocks other requests during inference, but it also
    # serializes use of the shared tflite_model, which the tf.lite.Interpreter
    # docs say must not be called from several threads at once.
    try:
        image_bytes = await file.read()
        try:
            image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except (OSError, Image.DecompressionBombError):
            # Empty, corrupt, truncated or oversized uploads are client errors
            # (PIL's UnidentifiedImageError subclasses OSError).
            return JSONResponse(status_code=400, content={"error": "Invalid image file"})

        is_good, issues = check_image_quality(image_pil)
        if not is_good:
            return JSONResponse(status_code=400, content={"error": "Low image quality", "issues": issues})

        if not validate_plant_image(image_pil):
            return JSONResponse(status_code=400, content={"error": "Image does not look like a plant"})

        input_details = tflite_model.get_input_details()
        output_details = tflite_model.get_output_details()
        tflite_model.set_tensor(input_details[0]["index"], _prepare_tflite_input(image_pil))
        tflite_model.invoke()
        output = tflite_model.get_tensor(output_details[0]["index"])[0]

        pred_idx = int(np.argmax(output))
        pred_name = class_names[pred_idx]
        top3_idx = np.argsort(output)[-3:][::-1]
        return {
            "prediction": pred_name,
            # Raw output score; whether it is a probability depends on the
            # model's final layer -- TODO confirm (softmax vs. logits).
            "confidence": float(output[pred_idx]),
            "top_3_predictions": {class_names[i]: round(float(output[i]), 4) for i in top3_idx},
            "pdf_url": plant_pdf_map.get(pred_name, "https://kampa.karnataka.gov.in/"),
        }
    except Exception as e:
        # Last-resort boundary: keep the traceback server-side and preserve
        # the existing 500 response shape for clients.
        logger.exception("Prediction failed")
        return JSONResponse(status_code=500, content={"error": str(e)})