File size: 4,887 Bytes
b721898
 
 
 
 
 
 
 
 
b58f415
 
 
 
b721898
 
 
 
 
 
 
b58f415
 
b721898
b58f415
 
ef226d4
c040cff
b58f415
61e95cf
b721898
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
import numpy as np
import cv2
from PIL import Image
import io
import torch
import clip
import tensorflow as tf
import os




# FastAPI application object; the route below is registered on it.
app = FastAPI()

# Load models
# TFLite interpreter for the plant classifier. The model file is looked up
# relative to the process working directory. Tensors have to be allocated
# once before set_tensor()/invoke() can be used.
tflite_model = tf.lite.Interpreter(model_path="resnet_model.tflite")
tflite_model.allocate_tensors()

# Run CLIP on the GPU when torch reports one, otherwise fall back to the CPU.
clip_device = "cuda" if torch.cuda.is_available() else "cpu"

# Create a writable models directory for the CLIP weights.
# NOTE(review): /tmp is presumably used because the default download location
# is not writable in the deployment environment -- confirm.
os.makedirs("/tmp/models", exist_ok=True)

# Loads the ViT-B/32 CLIP model and its matching image preprocessing callable,
# storing any downloaded weights under /tmp/models.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=clip_device, download_root="/tmp/models")


# Class labels for the classifier. The position of each name is the index of
# the matching score in the model output (predict() looks names up by argmax
# index), so the order must not be changed. Spellings are kept exactly as-is
# because they are returned to clients and used as keys in plant_pdf_map.
class_names = [
    "Nooni",
    "Nithyapushpa",
    "Basale",
    "Pomegranate",
    "Honge",
    "Lemon_grass",
    "Mint",
    "Betel_Nut",
    "Nagadali",
    "Curry_Leaf",
    "Jasmine",
    "Castor",
    "Sapota",
    "Neem",
    "Ashoka",
    "Brahmi",
    "Amruta_Balli",
    "Pappaya",
    "Pepper",
    "Wood_sorel",
    "Gauva",
    "Hibiscus",
    "Ashwagandha",
    "Aloevera",
    "Raktachandini",
    "Insulin",
    "Bamboo",
    "Amla",
    "Arali",
    "Geranium",
    "Avacado",
    "Lemon",
    "Ekka",
    "Betel",
    "Henna",
    "Doddapatre",
    "Rose",
    "Mango",
    "Tulasi",
    "Ganike",
]

# Brochure PDFs for the plants that have one. All brochures live under the
# same directory, so only the file name differs per plant (note that the
# "Tulasi" class maps to "tulsi.pdf"). Plants missing from this map get a
# fallback URL where the map is queried.
_BROCHURE_BASE_URL = (
    "https://kampa.karnataka.gov.in/storage/pdf-files/"
    "brochure%20of%20medicinal%20plants/"
)

plant_pdf_map = {
    plant: f"{_BROCHURE_BASE_URL}{pdf_file}"
    for plant, pdf_file in (
        ("Tulasi", "tulsi.pdf"),
        ("Neem", "neem.pdf"),
        ("Mint", "mint.pdf"),
        ("Aloevera", "aloevera.pdf"),
    )
}

# Helpers
def check_image_quality(image_pil):
    """Run basic sanity checks on an image before it is classified.

    The pixel statistics are taken over the whole array (all channels
    together): the mean is used as brightness and the standard deviation
    as a contrast measure.

    Args:
        image_pil: Image object exposing ``width`` and ``height`` and
            convertible with ``np.array`` (a PIL image in this service).

    Returns:
        A ``(is_good, issues)`` tuple. ``issues`` lists every failed check
        in a fixed order ("Too dark", "Too bright", "Low contrast",
        "Too small"); ``is_good`` is ``True`` only when that list is empty.
    """
    pixels = np.array(image_pil)
    mean_level = np.mean(pixels)
    spread = np.std(pixels)

    # (label, failed?) pairs, in the order the labels are reported.
    checks = (
        ("Too dark", mean_level < 30),
        ("Too bright", mean_level > 220),
        ("Low contrast", spread < 15),
        ("Too small", image_pil.width < 100 or image_pil.height < 100),
    )
    issues = [label for label, failed in checks if failed]
    return not issues, issues

def validate_plant_image(image_pil):
    """Decide with CLIP whether an image looks like a plant.

    The image is scored against a set of plant prompts and a set of
    non-plant prompts in one softmax; the image is accepted when the mean
    probability of the plant prompts exceeds the mean probability of the
    non-plant prompts.

    Args:
        image_pil: Image accepted by the module-level ``clip_preprocess``.

    Returns:
        Truthy when the plant prompts win the comparison, falsy otherwise.
    """
    # Single-image batch on the same device as the CLIP model.
    image_batch = clip_preprocess(image_pil).unsqueeze(0).to(clip_device)

    positive_prompts = [
        "a photo of a plant",
        "a photo of a leaf",
        "a photo of a green plant",
        "a photo of a medicinal plant",
        "a photo of herbs",
    ]
    negative_prompts = [
        "a photo of a person",
        "a photo of food",
        "a document",
        "a vehicle",
    ]
    # Plant prompts come first, so this index splits the probability vector.
    split_at = len(positive_prompts)
    text_tokens = clip.tokenize(positive_prompts + negative_prompts).to(clip_device)

    with torch.no_grad():
        # Only the first model output is used; it is softmaxed over the prompts.
        image_logits, _ = clip_model(image_batch, text_tokens)
        prompt_probs = image_logits.softmax(dim=-1).cpu().numpy()[0]

    plant_score = np.mean(prompt_probs[:split_at])
    other_score = np.mean(prompt_probs[split_at:])
    return plant_score > other_score

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded plant photo.

    Pipeline: decode the upload, reject low-quality images, reject images
    CLIP does not consider a plant, then run the TFLite classifier.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        On success, a dict with ``prediction`` (class name), ``confidence``
        (score of the top class), ``top_3_predictions`` (name -> score rounded
        to 4 decimals) and ``pdf_url`` (brochure link, or the portal home page
        when the plant has no brochure).
        A 400 ``JSONResponse`` when the upload cannot be decoded as an image,
        fails the quality checks, or does not look like a plant.
        A 500 ``JSONResponse`` carrying the exception text for any other
        failure.
    """
    try:
        image_bytes = await file.read()

        # A non-image, empty or truncated upload is a client error, not a
        # server fault. Pillow signals these with UnidentifiedImageError (an
        # OSError subclass) or a plain OSError; oversized images raise
        # DecompressionBombError.
        try:
            image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except (OSError, Image.DecompressionBombError):
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid or unreadable image file"},
            )

        # Image quality
        is_good, issues = check_image_quality(image_pil)
        if not is_good:
            return JSONResponse(status_code=400, content={"error": "Low image quality", "issues": issues})

        # Plant validation
        if not validate_plant_image(image_pil):
            return JSONResponse(status_code=400, content={"error": "Image does not look like a plant"})

        # Prepare for TFLite: BGR channel order, 224x224, scaled to [0, 1],
        # float32, with a leading batch axis.
        # NOTE(review): presumably this mirrors the preprocessing used when
        # the model was trained -- confirm before changing it.
        image = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
        image = cv2.resize(image, (224, 224))
        image = image / 255.0
        image = image.astype(np.float32)
        image = np.expand_dims(image, axis=0)

        input_details = tflite_model.get_input_details()
        output_details = tflite_model.get_output_details()

        tflite_model.set_tensor(input_details[0]['index'], image)
        tflite_model.invoke()
        # First (only) batch row of the first output tensor: per-class scores.
        output = tflite_model.get_tensor(output_details[0]['index'])[0]

        pred_idx = int(np.argmax(output))
        pred_name = class_names[pred_idx]
        confidence = float(output[pred_idx])
        # argsort is ascending: take the last three and reverse for best-first.
        top3_idx = np.argsort(output)[-3:][::-1]
        top_preds = {class_names[i]: round(float(output[i]), 4) for i in top3_idx}
        pdf_url = plant_pdf_map.get(pred_name, "https://kampa.karnataka.gov.in/")

        return {
            "prediction": pred_name,
            "confidence": confidence,
            "top_3_predictions": top_preds,
            "pdf_url": pdf_url
        }
    except Exception as e:
        # Endpoint boundary: anything unexpected becomes a JSON 500 response.
        # NOTE(review): the raw exception text is sent to the client.
        return JSONResponse(status_code=500, content={"error": str(e)})