Spaces:
Sleeping
Sleeping
File size: 4,887 Bytes
b721898 b58f415 b721898 b58f415 b721898 b58f415 ef226d4 c040cff b58f415 61e95cf b721898 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import io
import logging
import os

import clip
import cv2
import numpy as np
import tensorflow as tf
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image

app = FastAPI()

# --- Model loading ------------------------------------------------------------
# Both models are loaded once at import time and shared by every request.

# TFLite image classifier. The relative path is resolved against the process's
# current working directory. Tensors must be allocated before the first invoke().
tflite_model = tf.lite.Interpreter(model_path="resnet_model.tflite")
tflite_model.allocate_tensors()

# CLIP runs on the GPU when torch can see one, otherwise on the CPU.
clip_device = "cuda" if torch.cuda.is_available() else "cpu"

# clip.load() stores/reads the ViT-B/32 checkpoint under download_root, which
# must be a writable directory -- hence a location under /tmp.
CLIP_DOWNLOAD_ROOT = "/tmp/models"
os.makedirs(CLIP_DOWNLOAD_ROOT, exist_ok=True)
clip_model, clip_preprocess = clip.load(
    "ViT-B/32", device=clip_device, download_root=CLIP_DOWNLOAD_ROOT
)
# Classifier labels. The model's argmax index is used directly as a position in
# this list, and the strings are returned to clients verbatim. Presumably this
# is the training label order/spelling -- do not reorder or "fix" spellings.
class_names = [
    # 0-9
    "Nooni", "Nithyapushpa", "Basale", "Pomegranate", "Honge",
    "Lemon_grass", "Mint", "Betel_Nut", "Nagadali", "Curry_Leaf",
    # 10-19
    "Jasmine", "Castor", "Sapota", "Neem", "Ashoka",
    "Brahmi", "Amruta_Balli", "Pappaya", "Pepper", "Wood_sorel",
    # 20-29
    "Gauva", "Hibiscus", "Ashwagandha", "Aloevera", "Raktachandini",
    "Insulin", "Bamboo", "Amla", "Arali", "Geranium",
    # 30-39
    "Avacado", "Lemon", "Ekka", "Betel", "Henna",
    "Doddapatre", "Rose", "Mango", "Tulasi", "Ganike",
]

# KAMPA brochure PDFs for the subset of classes that have one. Keys must match
# entries of class_names. Labels without an entry get no brochure link here.
_KAMPA_BROCHURE_DIR = (
    "https://kampa.karnataka.gov.in/storage/pdf-files/"
    "brochure%20of%20medicinal%20plants/"
)
plant_pdf_map = {
    label: _KAMPA_BROCHURE_DIR + pdf_file
    for label, pdf_file in (
        ("Tulasi", "tulsi.pdf"),
        ("Neem", "neem.pdf"),
        ("Mint", "mint.pdf"),
        ("Aloevera", "aloevera.pdf"),
    )
}
# Helpers
def check_image_quality(
    image_pil,
    *,
    min_brightness=30,
    max_brightness=220,
    min_contrast=15,
    min_side=100,
):
    """Run cheap heuristic quality checks on an image before inference.

    Brightness is the mean, and contrast the standard deviation, of all
    values in ``np.array(image_pil)`` (every channel pooled together).

    Args:
        image_pil: A PIL image (anything exposing ``width``/``height`` and
            convertible with ``np.array``). Pixel values are assumed to be
            0-255 as for an RGB image -- TODO confirm for other modes.
        min_brightness: Mean pixel value below which the image is "Too dark".
        max_brightness: Mean pixel value above which it is "Too bright".
        min_contrast: Pixel standard deviation below which it is
            "Low contrast".
        min_side: Minimum width and height in pixels; anything narrower or
            shorter is "Too small".

    Returns:
        ``(is_good, issues)``. ``is_good`` is True when no check failed, and
        ``issues`` lists the failed checks' messages in the fixed order
        dark, bright, low contrast, small.
    """
    pixels = np.array(image_pil)
    brightness = np.mean(pixels)
    contrast = np.std(pixels)
    too_small = image_pil.width < min_side or image_pil.height < min_side

    checks = (
        (brightness < min_brightness, "Too dark"),
        (brightness > max_brightness, "Too bright"),
        (contrast < min_contrast, "Low contrast"),
        (too_small, "Too small"),
    )
    issues = [message for failed, message in checks if failed]
    return not issues, issues
def validate_plant_image(image_pil):
    """Zero-shot check, via CLIP, that an image looks like a plant.

    The image is scored against a fixed set of plant prompts and non-plant
    prompts. The softmax probability of each prompt is averaged within its
    group, and the image is accepted when the plant group's average wins.

    Args:
        image_pil: PIL image accepted by ``clip_preprocess``.

    Returns:
        A numpy boolean: True when the mean plant-prompt probability is
        strictly greater than the mean non-plant-prompt probability.
    """
    plant_prompts = [
        "a photo of a plant",
        "a photo of a leaf",
        "a photo of a green plant",
        "a photo of a medicinal plant",
        "a photo of herbs",
    ]
    non_plant_prompts = [
        "a photo of a person",
        "a photo of food",
        "a document",
        "a vehicle",
    ]
    n_plant = len(plant_prompts)

    image_batch = clip_preprocess(image_pil).unsqueeze(0).to(clip_device)
    text_tokens = clip.tokenize(plant_prompts + non_plant_prompts).to(clip_device)

    with torch.no_grad():
        logits_per_image, _ = clip_model(image_batch, text_tokens)
        # Single image -> take row 0; softmax is across the prompts.
        prompt_probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0]

    plant_score = prompt_probs[:n_plant].mean()
    other_score = prompt_probs[n_plant:].mean()
    return plant_score > other_score
def _decode_upload(image_bytes):
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        OSError: if Pillow cannot identify or decode the data
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    with Image.open(io.BytesIO(image_bytes)) as img:
        # convert() loads the pixel data, so decode errors surface here and
        # the returned image does not depend on the (now closed) source.
        return img.convert("RGB")


def _preprocess_for_tflite(image_pil, size=(224, 224)):
    """Turn an RGB PIL image into a float32 batch of shape (1, H, W, 3).

    Steps: RGB->BGR channel swap, resize to ``size`` (width, height), scale
    by 1/255, cast to float32, add a leading batch axis.
    """
    # NOTE(review): BGR order, 224x224 and /255 scaling are presumably what the
    # model was trained with -- confirm against the training pipeline.
    bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
    resized = cv2.resize(bgr, size)
    scaled = (resized / 255.0).astype(np.float32)
    return np.expand_dims(scaled, axis=0)


def _run_classifier(batch):
    """Run the shared TFLite interpreter on ``batch``; return output row 0."""
    input_details = tflite_model.get_input_details()
    output_details = tflite_model.get_output_details()
    tflite_model.set_tensor(input_details[0]['index'], batch)
    tflite_model.invoke()
    return tflite_model.get_tensor(output_details[0]['index'])[0]


def _format_prediction(scores):
    """Build the success payload from per-class scores.

    ``scores[i]`` is the score for ``class_names[i]``. These are presumably
    softmax probabilities -- confirm against the exported model.
    """
    pred_idx = int(np.argmax(scores))
    pred_name = class_names[pred_idx]
    top3_idx = np.argsort(scores)[-3:][::-1]
    return {
        "prediction": pred_name,
        "confidence": float(scores[pred_idx]),
        "top_3_predictions": {
            class_names[i]: round(float(scores[i]), 4) for i in top3_idx
        },
        "pdf_url": plant_pdf_map.get(pred_name, "https://kampa.karnataka.gov.in/"),
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded medicinal-plant image.

    Pipeline: decode -> quality heuristics -> CLIP "is it a plant" gate ->
    TFLite classification.

    Responses:
        200: ``prediction``, ``confidence``, ``top_3_predictions`` (label ->
            score rounded to 4 places) and ``pdf_url`` (brochure link, or the
            KAMPA home page when no brochure is mapped).
        400: the upload is not a decodable image, fails the quality checks
            (``issues`` says why), or does not look like a plant.
        500: any other failure; the exception text is returned in ``error``
            and the traceback is logged server-side.
    """
    try:
        image_bytes = await file.read()
        try:
            image_pil = _decode_upload(image_bytes)
        except OSError:
            # A corrupt, empty or non-image upload is a client error, not a 500.
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid image file"},
            )

        is_good, issues = check_image_quality(image_pil)
        if not is_good:
            return JSONResponse(
                status_code=400,
                content={"error": "Low image quality", "issues": issues},
            )

        if not validate_plant_image(image_pil):
            return JSONResponse(
                status_code=400,
                content={"error": "Image does not look like a plant"},
            )

        # The interpreter is shared module state. Inference below is synchronous
        # (no await between set_tensor and get_tensor), so concurrent requests on
        # the event loop cannot interleave inside it. Moving this work to a
        # thread pool would require a lock around _run_classifier.
        scores = _run_classifier(_preprocess_for_tflite(image_pil))
        return _format_prediction(scores)
    except Exception as e:
        # Top-level boundary: keep the JSON error contract, but record the
        # traceback so server-side failures are not invisible.
        logging.getLogger(__name__).exception("Prediction failed")
        return JSONResponse(status_code=500, content={"error": str(e)})
|