omkar90 commited on
Commit
b721898
·
verified ·
1 Parent(s): 15f0a36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -118
app.py CHANGED
@@ -1,118 +1,125 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from fastapi.responses import JSONResponse
3
- import numpy as np
4
- import cv2
5
- from PIL import Image
6
- import io
7
- import torch
8
- import clip
9
- import tensorflow as tf
10
-
11
- app = FastAPI()
12
-
13
- # Load models
14
- tflite_model = tf.lite.Interpreter(model_path="resnet_model.tflite")
15
- tflite_model.allocate_tensors()
16
-
17
- clip_model, clip_preprocess, clip_device = clip.load("ViT-B/32", device="cuda" if torch.cuda.is_available() else "cpu")
18
-
19
- # Class names
20
- class_names = [
21
- "Nooni", "Nithyapushpa", "Basale", "Pomegranate", "Honge",
22
- "Lemon_grass", "Mint", "Betel_Nut", "Nagadali", "Curry_Leaf",
23
- "Jasmine", "Castor", "Sapota", "Neem", "Ashoka", "Brahmi",
24
- "Amruta_Balli", "Pappaya", "Pepper", "Wood_sorel", "Gauva",
25
- "Hibiscus", "Ashwagandha", "Aloevera", "Raktachandini",
26
- "Insulin", "Bamboo", "Amla", "Arali", "Geranium", "Avacado",
27
- "Lemon", "Ekka", "Betel", "Henna", "Doddapatre", "Rose",
28
- "Mango", "Tulasi", "Ganike"
29
- ]
30
-
31
- plant_pdf_map = {
32
- "Tulasi": "https://kampa.karnataka.gov.in/storage/pdf-files/brochure%20of%20medicinal%20plants/tulsi.pdf",
33
- "Neem": "https://kampa.karnataka.gov.in/storage/pdf-files/brochure%20of%20medicinal%20plants/neem.pdf",
34
- "Mint": "https://kampa.karnataka.gov.in/storage/pdf-files/brochure%20of%20medicinal%20plants/mint.pdf",
35
- "Aloevera": "https://kampa.karnataka.gov.in/storage/pdf-files/brochure%20of%20medicinal%20plants/aloevera.pdf"
36
- }
37
-
38
- # Helpers
39
- def check_image_quality(image_pil):
40
- img_array = np.array(image_pil)
41
- brightness = np.mean(img_array)
42
- color_std = np.std(img_array)
43
- too_dark = brightness < 30
44
- too_bright = brightness > 220
45
- low_contrast = color_std < 15
46
- too_small = image_pil.width < 100 or image_pil.height < 100
47
- is_good = not (too_dark or too_bright or low_contrast or too_small)
48
- issues = []
49
- if too_dark: issues.append("Too dark")
50
- if too_bright: issues.append("Too bright")
51
- if low_contrast: issues.append("Low contrast")
52
- if too_small: issues.append("Too small")
53
- return is_good, issues
54
-
55
- def validate_plant_image(image_pil):
56
- clip_image = clip_preprocess(image_pil).unsqueeze(0).to(clip_device)
57
- plant_prompts = [
58
- "a photo of a plant", "a photo of a leaf", "a photo of a green plant",
59
- "a photo of a medicinal plant", "a photo of herbs"
60
- ]
61
- non_plant_prompts = [
62
- "a photo of a person", "a photo of food", "a document", "a vehicle"
63
- ]
64
- all_prompts = plant_prompts + non_plant_prompts
65
- tokens = clip.tokenize(all_prompts).to(clip_device)
66
-
67
- with torch.no_grad():
68
- logits, _ = clip_model(clip_image, tokens)
69
- probs = logits.softmax(dim=-1).cpu().numpy()[0]
70
-
71
- plant_conf = np.mean(probs[:len(plant_prompts)])
72
- non_plant_conf = np.mean(probs[len(plant_prompts):])
73
- return plant_conf > non_plant_conf
74
-
75
- @app.post("/predict")
76
- async def predict(file: UploadFile = File(...)):
77
- try:
78
- image_bytes = await file.read()
79
- image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
80
-
81
- # Image quality
82
- is_good, issues = check_image_quality(image_pil)
83
- if not is_good:
84
- return JSONResponse(status_code=400, content={"error": "Low image quality", "issues": issues})
85
-
86
- # Plant validation
87
- if not validate_plant_image(image_pil):
88
- return JSONResponse(status_code=400, content={"error": "Image does not look like a plant"})
89
-
90
- # Prepare for TFLite
91
- image = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
92
- image = cv2.resize(image, (224, 224))
93
- image = image / 255.0
94
- image = image.astype(np.float32)
95
- image = np.expand_dims(image, axis=0)
96
-
97
- input_details = tflite_model.get_input_details()
98
- output_details = tflite_model.get_output_details()
99
-
100
- tflite_model.set_tensor(input_details[0]['index'], image)
101
- tflite_model.invoke()
102
- output = tflite_model.get_tensor(output_details[0]['index'])[0]
103
-
104
- pred_idx = int(np.argmax(output))
105
- pred_name = class_names[pred_idx]
106
- confidence = float(output[pred_idx])
107
- top3_idx = np.argsort(output)[-3:][::-1]
108
- top_preds = {class_names[i]: round(float(output[i]), 4) for i in top3_idx}
109
- pdf_url = plant_pdf_map.get(pred_name, "https://kampa.karnataka.gov.in/")
110
-
111
- return {
112
- "prediction": pred_name,
113
- "confidence": confidence,
114
- "top_3_predictions": top_preds,
115
- "pdf_url": pdf_url
116
- }
117
- except Exception as e:
118
- return JSONResponse(status_code=500, content={"error": str(e)})
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
+ import numpy as np
4
+ import cv2
5
+ from PIL import Image
6
+ import io
7
+ import torch
8
+ import clip
9
+ import tensorflow as tf
10
+
11
+ app = FastAPI()
12
+
13
+ # Load models
14
+ tflite_model = tf.lite.Interpreter(model_path="resnet_model.tflite")
15
+ tflite_model.allocate_tensors()
16
+
17
+ clip_device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ clip_model, clip_preprocess = clip.load("ViT-B/32", device=clip_device)
19
+
20
+
21
+ # Class names
22
+ class_names = [
23
+ "Nooni", "Nithyapushpa", "Basale", "Pomegranate", "Honge",
24
+ "Lemon_grass", "Mint", "Betel_Nut", "Nagadali", "Curry_Leaf",
25
+ "Jasmine", "Castor", "Sapota", "Neem", "Ashoka", "Brahmi",
26
+ "Amruta_Balli", "Pappaya", "Pepper", "Wood_sorel", "Gauva",
27
+ "Hibiscus", "Ashwagandha", "Aloevera", "Raktachandini",
28
+ "Insulin", "Bamboo", "Amla", "Arali", "Geranium", "Avacado",
29
+ "Lemon", "Ekka", "Betel", "Henna", "Doddapatre", "Rose",
30
+ "Mango", "Tulasi", "Ganike"
31
+ ]
32
+
33
+ plant_pdf_map = {
34
+ "Tulasi": "https://kampa.karnataka.gov.in/storage/pdf-files/brochure%20of%20medicinal%20plants/tulsi.pdf",
35
+ "Neem": "https://kampa.karnataka.gov.in/storage/pdf-files/brochure%20of%20medicinal%20plants/neem.pdf",
36
+ "Mint": "https://kampa.karnataka.gov.in/storage/pdf-files/brochure%20of%20medicinal%20plants/mint.pdf",
37
+ "Aloevera": "https://kampa.karnataka.gov.in/storage/pdf-files/brochure%20of%20medicinal%20plants/aloevera.pdf"
38
+ }
39
+
40
+ # Helpers
41
+ def check_image_quality(image_pil):
42
+ img_array = np.array(image_pil)
43
+ brightness = np.mean(img_array)
44
+ color_std = np.std(img_array)
45
+ too_dark = brightness < 30
46
+ too_bright = brightness > 220
47
+ low_contrast = color_std < 15
48
+ too_small = image_pil.width < 100 or image_pil.height < 100
49
+ is_good = not (too_dark or too_bright or low_contrast or too_small)
50
+ issues = []
51
+ if too_dark: issues.append("Too dark")
52
+ if too_bright: issues.append("Too bright")
53
+ if low_contrast: issues.append("Low contrast")
54
+ if too_small: issues.append("Too small")
55
+ return is_good, issues
56
+
57
+ def validate_plant_image(image_pil):
58
+ clip_image = clip_preprocess(image_pil).unsqueeze(0).to(clip_device)
59
+ plant_prompts = [
60
+ "a photo of a plant", "a photo of a leaf", "a photo of a green plant",
61
+ "a photo of a medicinal plant", "a photo of herbs"
62
+ ]
63
+ non_plant_prompts = [
64
+ "a photo of a person", "a photo of food", "a document", "a vehicle"
65
+ ]
66
+ all_prompts = plant_prompts + non_plant_prompts
67
+ tokens = clip.tokenize(all_prompts).to(clip_device)
68
+
69
+ with torch.no_grad():
70
+ logits, _ = clip_model(clip_image, tokens)
71
+ probs = logits.softmax(dim=-1).cpu().numpy()[0]
72
+
73
+ plant_conf = np.mean(probs[:len(plant_prompts)])
74
+ non_plant_conf = np.mean(probs[len(plant_prompts):])
75
+ return plant_conf > non_plant_conf
76
+
77
+ @app.post("/predict")
78
+ async def predict(file: UploadFile = File(...)):
79
+ try:
80
+ image_bytes = await file.read()
81
+ image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
82
+
83
+ # Image quality
84
+ is_good, issues = check_image_quality(image_pil)
85
+ if not is_good:
86
+ return JSONResponse(status_code=400, content={"error": "Low image quality", "issues": issues})
87
+
88
+ # Plant validation
89
+ if not validate_plant_image(image_pil):
90
+ return JSONResponse(status_code=400, content={"error": "Image does not look like a plant"})
91
+
92
+ # Prepare for TFLite
93
+ image = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
94
+ image = cv2.resize(image, (224, 224))
95
+ image = image / 255.0
96
+ image = image.astype(np.float32)
97
+ image = np.expand_dims(image, axis=0)
98
+
99
+ input_details = tflite_model.get_input_details()
100
+ output_details = tflite_model.get_output_details()
101
+
102
+ tflite_model.set_tensor(input_details[0]['index'], image)
103
+ tflite_model.invoke()
104
+ output = tflite_model.get_tensor(output_details[0]['index'])[0]
105
+
106
+ pred_idx = int(np.argmax(output))
107
+ pred_name = class_names[pred_idx]
108
+ confidence = float(output[pred_idx])
109
+ top3_idx = np.argsort(output)[-3:][::-1]
110
+ top_preds = {class_names[i]: round(float(output[i]), 4) for i in top3_idx}
111
+ pdf_url = plant_pdf_map.get(pred_name, "https://kampa.karnataka.gov.in/")
112
+
113
+ return {
114
+ "prediction": pred_name,
115
+ "confidence": confidence,
116
+ "top_3_predictions": top_preds,
117
+ "pdf_url": pdf_url
118
+ }
119
+ except Exception as e:
120
+ return JSONResponse(status_code=500, content={"error": str(e)})
121
+
122
+
123
+ if __name__ == "__main__":
124
+ import uvicorn
125
+ uvicorn.run(app, host="0.0.0.0", port=7860)