nivakaran commited on
Commit
f18a161
·
verified ·
1 Parent(s): 6a45ca6

Update max.py

Browse files
Files changed (1) hide show
  1. max.py +133 -0
max.py CHANGED
@@ -5,8 +5,10 @@ import re
5
  import logging
6
  import tempfile
7
  import base64
 
8
  from uuid import uuid4
9
  from typing import Optional, List
 
10
  from fastapi import FastAPI, UploadFile, File, HTTPException
11
  from fastapi.responses import JSONResponse
12
  from fastapi.middleware.cors import CORSMiddleware
@@ -245,6 +247,9 @@ class PredictionResponse(BaseModel):
245
  category: str
246
  output_image: Optional[str] = None # Base64-encoded output image
247
 
 
 
 
248
  class QuestionRequest(BaseModel):
249
  session_id: str
250
  question: str
@@ -293,6 +298,69 @@ food_categories = {
293
  "Sauce Condiments and Seasonings": ['guacamole' ],
294
  }
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  # Image classification endpoints
297
  @app.post(
298
  "/predict",
@@ -355,6 +423,70 @@ async def predict_image(file: UploadFile = File(..., description="A food image i
355
  logger.info(f"Prediction for {file.filename}: {label} (Category: {category})")
356
  return PredictionResponse(label=label, category=category, output_image=output_image)
357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  # E-commerce assistant endpoints
359
  def execute_function_call(raw_response: str) -> str:
360
  think_end = raw_response.find('</think>')
@@ -492,6 +624,7 @@ async def root():
492
  "message": "Welcome to the EcoHarvest Combined API",
493
  "endpoints": {
494
  "image_classification": "/predict",
 
495
  "ecommerce_assistant": "/ask",
496
  "food_categories": "/food/categories",
497
  "docs": "/docs"
 
5
  import logging
6
  import tempfile
7
  import base64
8
+ import requests
9
  from uuid import uuid4
10
  from typing import Optional, List
11
+ from urllib.parse import urlparse
12
  from fastapi import FastAPI, UploadFile, File, HTTPException
13
  from fastapi.responses import JSONResponse
14
  from fastapi.middleware.cors import CORSMiddleware
 
247
  category: str
248
  output_image: Optional[str] = None # Base64-encoded output image
249
 
250
+ class ImageUrlRequest(BaseModel):
251
+ image_url: str
252
+
253
  class QuestionRequest(BaseModel):
254
  session_id: str
255
  question: str
 
298
  "Sauce Condiments and Seasonings": ['guacamole' ],
299
  }
300
 
301
+ # Helper functions for URL processing
302
+ def convert_google_drive_url(url: str) -> str:
303
+ """Convert Google Drive share URL to direct download URL."""
304
+ if 'drive.google.com' in url:
305
+ if '/file/d/' in url:
306
+ # Extract file ID from share URL
307
+ match = re.search(r'/file/d/([a-zA-Z0-9-_]+)', url)
308
+ if match:
309
+ file_id = match.group(1)
310
+ return f"https://drive.google.com/uc?id={file_id}&export=download"
311
+ elif '/uc?id=' in url:
312
+ # Already a direct download URL, ensure it has export=download
313
+ if 'export=download' not in url:
314
+ if '&' in url:
315
+ return url + "&export=download"
316
+ else:
317
+ return url + "?export=download"
318
+ return url
319
+ elif 'drive.usercontent.google.com' in url:
320
+ # Handle usercontent URLs
321
+ return url.replace('export=view', 'export=download')
322
+
323
+ # Return original URL if not a Google Drive URL
324
+ return url
325
+
326
+ def download_image_from_url(url: str) -> Image.Image:
327
+ """Download image from URL and return PIL Image object."""
328
+ try:
329
+ # Set headers to mimic a browser request
330
+ headers = {
331
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
332
+ 'Accept': 'image/webp,image/apng,image/*,*/*;q=0.8',
333
+ 'Accept-Language': 'en-US,en;q=0.9',
334
+ 'Accept-Encoding': 'gzip, deflate, br',
335
+ 'Connection': 'keep-alive',
336
+ 'Upgrade-Insecure-Requests': '1',
337
+ }
338
+
339
+ # Download the image with a timeout
340
+ response = requests.get(url, headers=headers, timeout=30, stream=True)
341
+ response.raise_for_status()
342
+
343
+ # Check content type
344
+ content_type = response.headers.get('content-type', '').lower()
345
+ if not content_type.startswith('image/'):
346
+ # Sometimes Google Drive doesn't return proper content-type, so check content
347
+ if len(response.content) < 1000 and b'html' in response.content.lower():
348
+ raise ValueError(f"URL returned HTML content, not an image. Please ensure the Google Drive file is publicly accessible.")
349
+
350
+ # Open image from response content
351
+ image = Image.open(io.BytesIO(response.content)).convert("RGB")
352
+ return image
353
+
354
+ except requests.exceptions.RequestException as e:
355
+ logger.error(f"Failed to download image from URL {url}: {str(e)}")
356
+ raise HTTPException(status_code=400, detail=f"Failed to download image: {str(e)}. Please ensure the Google Drive link is publicly accessible.")
357
+ except UnidentifiedImageError:
358
+ logger.error(f"Invalid image file from URL {url}")
359
+ raise HTTPException(status_code=400, detail="URL does not contain a valid image file")
360
+ except Exception as e:
361
+ logger.error(f"Error processing image from URL {url}: {str(e)}")
362
+ raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
363
+
364
  # Image classification endpoints
365
  @app.post(
366
  "/predict",
 
423
  logger.info(f"Prediction for {file.filename}: {label} (Category: {category})")
424
  return PredictionResponse(label=label, category=category, output_image=output_image)
425
 
426
+ @app.post(
427
+ "/predict-from-url",
428
+ response_model=PredictionResponse,
429
+ summary="Classify a food image from URL",
430
+ description="Provide an image URL (preferably Google Drive) to classify it into one of 101 food categories and return its category."
431
+ )
432
+ async def predict_image_from_url(request: ImageUrlRequest):
433
+ image_url = request.image_url.strip()
434
+ logger.info(f"Received image URL for prediction: {image_url}")
435
+
436
+ if not image_url:
437
+ raise HTTPException(status_code=400, detail="Image URL is required")
438
+
439
+ # Validate URL format
440
+ try:
441
+ parsed_url = urlparse(image_url)
442
+ if not parsed_url.scheme or not parsed_url.netloc:
443
+ raise HTTPException(status_code=400, detail="Invalid URL format")
444
+ except Exception:
445
+ raise HTTPException(status_code=400, detail="Invalid URL format")
446
+
447
+ # Convert Google Drive URL to direct download URL if needed
448
+ direct_url = convert_google_drive_url(image_url)
449
+ logger.info(f"Converted URL: {direct_url}")
450
+
451
+ # Download and process image
452
+ try:
453
+ image = download_image_from_url(direct_url)
454
+ except HTTPException:
455
+ raise
456
+ except Exception as e:
457
+ logger.error(f"Unexpected error downloading image: {str(e)}")
458
+ raise HTTPException(status_code=500, detail="Failed to process image from URL")
459
+
460
+ # Predict using the downloaded image
461
+ try:
462
+ # Create a temporary file for prediction
463
+ fd, temp_file_path = tempfile.mkstemp(suffix=".jpg")
464
+ try:
465
+ image.save(temp_file_path)
466
+ label, output_image_path = classifier.predict(temp_file_path)
467
+ finally:
468
+ os.close(fd)
469
+ if os.path.exists(temp_file_path):
470
+ os.remove(temp_file_path) # Clean up
471
+ except Exception as e:
472
+ logger.error(f"Prediction error: {str(e)}")
473
+ raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
474
+
475
+ # Determine category
476
+ category = next((cat for cat, foods in food_categories.items() if label in foods), "Uncategorized")
477
+
478
+ # Encode output image as base64 if available
479
+ output_image = None
480
+ if output_image_path and os.path.exists(output_image_path):
481
+ try:
482
+ with open(output_image_path, "rb") as f:
483
+ output_image = base64.b64encode(f.read()).decode("utf-8")
484
+ except Exception as e:
485
+ logger.warning(f"Failed to encode output image: {str(e)}")
486
+
487
+ logger.info(f"Prediction for URL {image_url}: {label} (Category: {category})")
488
+ return PredictionResponse(label=label, category=category, output_image=output_image)
489
+
490
  # E-commerce assistant endpoints
491
  def execute_function_call(raw_response: str) -> str:
492
  think_end = raw_response.find('</think>')
 
624
  "message": "Welcome to the EcoHarvest Combined API",
625
  "endpoints": {
626
  "image_classification": "/predict",
627
+ "image_classification_from_url": "/predict-from-url",
628
  "ecommerce_assistant": "/ask",
629
  "food_categories": "/food/categories",
630
  "docs": "/docs"