Charuka66 commited on
Commit
01efda8
Β·
verified Β·
1 Parent(s): b00dada

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +18 -71
main.py CHANGED
@@ -1,17 +1,14 @@
1
- from fastapi import FastAPI, File, UploadFile
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from typing import List
4
  from sahi import AutoDetectionModel
5
  from sahi.predict import get_sliced_prediction
6
- from PIL import Image, ExifTags
7
  import uvicorn
8
  import shutil
9
  import os
10
- import io
11
 
12
  app = FastAPI()
13
 
14
- # SECURITY: Allow access
15
  app.add_middleware(
16
  CORSMiddleware,
17
  allow_origins=["*"],
@@ -20,14 +17,12 @@ app.add_middleware(
20
  allow_headers=["*"],
21
  )
22
 
23
- # ==========================================
24
- # 🧠 LOAD SAHI MODEL
25
- # ==========================================
26
  print("⏳ Loading SAHI + YOLO Model...")
27
  try:
28
  detection_model = AutoDetectionModel.from_pretrained(
29
  model_type='yolov8',
30
- model_path='best.pt', # Your trained model
31
  confidence_threshold=0.25,
32
  device='cpu'
33
  )
@@ -35,53 +30,7 @@ try:
35
  except Exception as e:
36
  print(f"❌ Error loading model: {e}")
37
 
38
- # ==========================================
39
- # πŸ“ GPS HELPER FUNCTIONS (NEW)
40
- # ==========================================
41
- def get_decimal_from_dms(dms, ref):
42
- degrees = dms[0]
43
- minutes = dms[1]
44
- seconds = dms[2]
45
- decimal = degrees + (minutes / 60.0) + (seconds / 3600.0)
46
- if ref in ['S', 'W']:
47
- decimal = -decimal
48
- return decimal
49
-
50
- def get_image_coordinates(image):
51
- try:
52
- exif_data = image._getexif()
53
- if not exif_data:
54
- return None, None
55
-
56
- gps_info = {}
57
- for tag, value in exif_data.items():
58
- decoded = ExifTags.TAGS.get(tag, tag)
59
- if decoded == "GPSInfo":
60
- gps_info = value
61
- break
62
-
63
- if not gps_info:
64
- return None, None
65
-
66
- # Extract Lat/Lon
67
- lat_dms = gps_info.get(2) # GPSLatitude
68
- lat_ref = gps_info.get(1) # GPSLatitudeRef
69
- lon_dms = gps_info.get(4) # GPSLongitude
70
- lon_ref = gps_info.get(3) # GPSLongitudeRef
71
-
72
- if lat_dms and lat_ref and lon_dms and lon_ref:
73
- lat = get_decimal_from_dms(lat_dms, lat_ref)
74
- lon = get_decimal_from_dms(lon_dms, lon_ref)
75
- return lat, lon
76
-
77
- except Exception as e:
78
- print(f"⚠️ GPS Extract Error: {e}")
79
- return None, None
80
-
81
- return None, None
82
-
83
-
84
- # Helper function for recommendations
85
  def get_recommendation(disease_name):
86
  recommendations = {
87
  "Blast": "Use Tricyclazole 75 WP. Avoid excess nitrogen.",
@@ -95,28 +44,25 @@ def get_recommendation(disease_name):
95
 
96
  @app.get("/")
97
  def home():
98
- return {"message": "Goyam AI (SAHI Enabled) is Running! πŸš€"}
99
 
100
  # ---------------------------------------------------------
101
- # 1. SINGLE IMAGE PREDICTION (Updated for GPS)
102
  # ---------------------------------------------------------
103
  @app.post("/predict")
104
- async def predict(file: UploadFile = File(...)):
 
 
 
 
105
  print(f"πŸ“₯ Receiving image: {file.filename}")
 
106
 
107
- # Read file into memory to extract GPS + Save temp
108
- image_bytes = await file.read()
109
- image_pil = Image.open(io.BytesIO(image_bytes))
110
-
111
- # πŸ“ Extract GPS
112
- lat, lon = get_image_coordinates(image_pil)
113
- print(f"πŸ“ Found Location: {lat}, {lon}")
114
-
115
- # Save temp file for SAHI
116
  temp_filename = f"temp_{file.filename}"
117
  try:
 
118
  with open(temp_filename, "wb") as buffer:
119
- buffer.write(image_bytes)
120
 
121
  # Run SAHI
122
  result = get_sliced_prediction(
@@ -130,10 +76,11 @@ async def predict(file: UploadFile = File(...)):
130
 
131
  predictions = result.object_prediction_list
132
 
 
133
  response_data = {
134
  "filename": file.filename,
135
- "latitude": lat, # πŸ‘ˆ Sending GPS back
136
- "longitude": lon # πŸ‘ˆ Sending GPS back
137
  }
138
 
139
  if len(predictions) > 0:
 
1
+ from fastapi import FastAPI, File, UploadFile, Form
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from typing import List, Optional
4
  from sahi import AutoDetectionModel
5
  from sahi.predict import get_sliced_prediction
 
6
  import uvicorn
7
  import shutil
8
  import os
 
9
 
10
  app = FastAPI()
11
 
 
12
  app.add_middleware(
13
  CORSMiddleware,
14
  allow_origins=["*"],
 
17
  allow_headers=["*"],
18
  )
19
 
20
+ # 🧠 LOAD MODEL
 
 
21
  print("⏳ Loading SAHI + YOLO Model...")
22
  try:
23
  detection_model = AutoDetectionModel.from_pretrained(
24
  model_type='yolov8',
25
+ model_path='best.pt',
26
  confidence_threshold=0.25,
27
  device='cpu'
28
  )
 
30
  except Exception as e:
31
  print(f"❌ Error loading model: {e}")
32
 
33
+ # Helper for recommendations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def get_recommendation(disease_name):
35
  recommendations = {
36
  "Blast": "Use Tricyclazole 75 WP. Avoid excess nitrogen.",
 
44
 
45
  @app.get("/")
46
  def home():
47
+ return {"message": "Goyam AI is Running! πŸš€"}
48
 
49
  # ---------------------------------------------------------
50
+ # 1. PREDICT ENDPOINT (Now accepts explicit GPS)
51
  # ---------------------------------------------------------
52
  @app.post("/predict")
53
+ async def predict(
54
+ file: UploadFile = File(...),
55
+ latitude: Optional[str] = Form(None), # πŸ‘ˆ New text field
56
+ longitude: Optional[str] = Form(None) # πŸ‘ˆ New text field
57
+ ):
58
  print(f"πŸ“₯ Receiving image: {file.filename}")
59
+ print(f"πŸ“ Received GPS: {latitude}, {longitude}")
60
 
 
 
 
 
 
 
 
 
 
61
  temp_filename = f"temp_{file.filename}"
62
  try:
63
+ # Save temp file
64
  with open(temp_filename, "wb") as buffer:
65
+ shutil.copyfileobj(file.file, buffer)
66
 
67
  # Run SAHI
68
  result = get_sliced_prediction(
 
76
 
77
  predictions = result.object_prediction_list
78
 
79
+ # Prepare Response
80
  response_data = {
81
  "filename": file.filename,
82
+ "latitude": float(latitude) if latitude else None,
83
+ "longitude": float(longitude) if longitude else None
84
  }
85
 
86
  if len(predictions) > 0: