httpsAkayush commited on
Commit
b46325f
·
verified ·
1 Parent(s): 3d8dd57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -50
app.py CHANGED
@@ -1,27 +1,17 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException
2
- from fastapi.middleware.cors import CORSMiddleware
3
  import tensorflow as tf
4
  import numpy as np
5
  from PIL import Image
6
- import io
7
- import uvicorn
8
- import tempfile
9
  import cv2
10
 
11
- # Initialize FastAPI app
12
- app = FastAPI(title="Plant Disease Detection API", version="1.0.0")
13
 
14
- # Add CORS middleware to allow requests from your frontend
15
- app.add_middleware(
16
- CORSMiddleware,
17
- allow_origins=["*"], # In production, replace with your frontend URL
18
- allow_credentials=True,
19
- allow_methods=["*"],
20
- allow_headers=["*"],
21
- )
22
 
23
  # Load your model
24
- model = tf.keras.models.load_model('trained_modela.keras')
25
 
26
  # Define your class names (update with your actual classes)
27
  class_name = ['Apple___Apple_scab',
@@ -63,32 +53,20 @@ class_name = ['Apple___Apple_scab',
63
  'Tomato___Tomato_mosaic_virus',
64
  'Tomato___healthy']
65
 
66
-
67
- @app.get("/")
68
- async def root():
69
- return {"message": "Plant Disease Detection API", "version": "1.0.0"}
70
-
71
- @app.post("/predict")
72
- async def predict_disease(file: UploadFile = File(...)):
73
  """
74
- Predict plant disease from uploaded image
75
  """
76
  try:
77
- # Validate file type
78
- # Validate file type
79
- if not file.content_type.startswith('image/'):
80
- raise HTTPException(status_code=400, detail="File must be an image")
81
 
82
- # Save uploaded file temporarily
83
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
84
  temp_path = tmp.name
85
- contents = await file.read()
86
- tmp.write(contents)
87
 
88
  # Read image using OpenCV
89
  img = cv2.imread(temp_path)
90
- if img is None:
91
- raise HTTPException(status_code=400, detail="Invalid image file")
92
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
93
 
94
  image = tf.keras.preprocessing.image.load_img(temp_path,target_size=(128, 128))
@@ -102,25 +80,22 @@ async def predict_disease(file: UploadFile = File(...)):
102
  confidence = prediction[0][result_index]
103
  disease_name = class_name[result_index]
104
 
105
- return {
106
- "success": True,
107
- "disease": disease_name,
108
- "confidence": confidence
109
- }
110
 
111
- except HTTPException as he:
112
- raise he
113
  except Exception as e:
114
- raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
115
 
116
- @app.get("/health")
117
- async def health_check():
118
- return {"status": "healthy"}
119
-
120
- @app.get("/classes")
121
- async def get_classes():
122
- """Get all available disease classes"""
123
- return {"classes": class_name}
 
 
 
124
 
125
  if __name__ == "__main__":
126
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ import tempfile
 
2
  import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
+ import gradio as gr
6
+ import requests
7
+ from io import BytesIO
8
  import cv2
9
 
 
 
10
 
11
+ model = tf.keras.models.load_model('trained_modela.keras')
 
 
 
 
 
 
 
12
 
13
  # Load your model
14
+
15
 
16
  # Define your class names (update with your actual classes)
17
  class_name = ['Apple___Apple_scab',
 
53
  'Tomato___Tomato_mosaic_virus',
54
  'Tomato___healthy']
55
 
56
+ def predict_disease(image):
 
 
 
 
 
 
57
  """
58
+ Predict plant disease from uploaded image using same preprocessing as your working cv2 method
59
  """
60
  try:
 
 
 
 
61
 
 
62
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
63
  temp_path = tmp.name
64
+ image.save(temp_path)
65
+
66
 
67
  # Read image using OpenCV
68
  img = cv2.imread(temp_path)
69
+
 
70
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
71
 
72
  image = tf.keras.preprocessing.image.load_img(temp_path,target_size=(128, 128))
 
80
  confidence = prediction[0][result_index]
81
  disease_name = class_name[result_index]
82
 
83
+ return f"Disease: {disease_name}\nConfidence: {confidence:.2%}"
 
 
 
 
84
 
 
 
85
  except Exception as e:
86
+ return f"Error: {str(e)}"
87
 
88
+ # Create Gradio interface
89
+ iface = gr.Interface(
90
+ fn=predict_disease,
91
+ inputs=gr.Image(type="pil", label="Upload Plant Image"),
92
+ outputs=gr.Textbox(label="Prediction Result"),
93
+ title="Plant Disease Detection API",
94
+ description="Upload an image of a plant leaf to detect diseases",
95
+ examples=[
96
+ # You can add example images here
97
+ ]
98
+ )
99
 
100
  if __name__ == "__main__":
101
+ iface.launch()