Spaces:
Sleeping
Sleeping
Update max.py
Browse files
max.py
CHANGED
|
@@ -5,8 +5,10 @@ import re
|
|
| 5 |
import logging
|
| 6 |
import tempfile
|
| 7 |
import base64
|
|
|
|
| 8 |
from uuid import uuid4
|
| 9 |
from typing import Optional, List
|
|
|
|
| 10 |
from fastapi import FastAPI, UploadFile, File, HTTPException
|
| 11 |
from fastapi.responses import JSONResponse
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
|
@@ -245,6 +247,9 @@ class PredictionResponse(BaseModel):
|
|
| 245 |
category: str
|
| 246 |
output_image: Optional[str] = None # Base64-encoded output image
|
| 247 |
|
|
|
|
|
|
|
|
|
|
| 248 |
class QuestionRequest(BaseModel):
|
| 249 |
session_id: str
|
| 250 |
question: str
|
|
@@ -293,6 +298,69 @@ food_categories = {
|
|
| 293 |
"Sauce Condiments and Seasonings": ['guacamole' ],
|
| 294 |
}
|
| 295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
# Image classification endpoints
|
| 297 |
@app.post(
|
| 298 |
"/predict",
|
|
@@ -355,6 +423,70 @@ async def predict_image(file: UploadFile = File(..., description="A food image i
|
|
| 355 |
logger.info(f"Prediction for {file.filename}: {label} (Category: {category})")
|
| 356 |
return PredictionResponse(label=label, category=category, output_image=output_image)
|
| 357 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
# E-commerce assistant endpoints
|
| 359 |
def execute_function_call(raw_response: str) -> str:
|
| 360 |
think_end = raw_response.find('</think>')
|
|
@@ -492,6 +624,7 @@ async def root():
|
|
| 492 |
"message": "Welcome to the EcoHarvest Combined API",
|
| 493 |
"endpoints": {
|
| 494 |
"image_classification": "/predict",
|
|
|
|
| 495 |
"ecommerce_assistant": "/ask",
|
| 496 |
"food_categories": "/food/categories",
|
| 497 |
"docs": "/docs"
|
|
|
|
| 5 |
import logging
|
| 6 |
import tempfile
|
| 7 |
import base64
|
| 8 |
+
import requests
|
| 9 |
from uuid import uuid4
|
| 10 |
from typing import Optional, List
|
| 11 |
+
from urllib.parse import urlparse
|
| 12 |
from fastapi import FastAPI, UploadFile, File, HTTPException
|
| 13 |
from fastapi.responses import JSONResponse
|
| 14 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
| 247 |
category: str
|
| 248 |
output_image: Optional[str] = None # Base64-encoded output image
|
| 249 |
|
| 250 |
+
class ImageUrlRequest(BaseModel):
    """Request body for the /predict-from-url endpoint."""
    # URL of the image to classify; Google Drive share links are
    # accepted and converted to direct-download form by the endpoint.
    image_url: str
|
| 252 |
+
|
| 253 |
class QuestionRequest(BaseModel):
|
| 254 |
session_id: str
|
| 255 |
question: str
|
|
|
|
| 298 |
"Sauce Condiments and Seasonings": ['guacamole' ],
|
| 299 |
}
|
| 300 |
|
| 301 |
+
# Helper functions for URL processing
|
| 302 |
+
def convert_google_drive_url(url: str) -> str:
    """Convert a Google Drive share URL to a direct-download URL.

    Non-Google-Drive URLs are returned unchanged.

    Args:
        url: Any http(s) URL, possibly a Google Drive share link.

    Returns:
        A URL suitable for a direct image download.
    """
    if 'drive.google.com' in url:
        if '/file/d/' in url:
            # Share-style link: extract the file ID and rebuild as a
            # direct download URL.
            match = re.search(r'/file/d/([a-zA-Z0-9-_]+)', url)
            if match:
                file_id = match.group(1)
                return f"https://drive.google.com/uc?id={file_id}&export=download"
        elif '/uc?id=' in url:
            # Already a direct download URL; ensure export=download is set.
            if 'export=download' not in url:
                # BUGFIX: the URL already contains '?' (from '/uc?id='), so the
                # extra parameter must be joined with '&'. The previous check
                # ('&' in url) appended a second '?' for single-parameter URLs.
                separator = '&' if '?' in url else '?'
                return url + separator + "export=download"
            return url
    elif 'drive.usercontent.google.com' in url:
        # Handle usercontent URLs: force download instead of inline view.
        return url.replace('export=view', 'export=download')

    # Return original URL if not a Google Drive URL
    return url
|
| 325 |
+
|
| 326 |
+
def download_image_from_url(url: str) -> Image.Image:
    """Download an image from *url* and return it as an RGB PIL Image.

    Args:
        url: Direct-download URL of the image.

    Raises:
        HTTPException(400): download failed, or the URL served HTML / a
            non-image payload (e.g. a private Google Drive file).
        HTTPException(500): any other processing error.
    """
    try:
        # Browser-like headers: some hosts (Google Drive included) reject
        # the default python-requests User-Agent.
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Accept': 'image/webp,image/apng,image/*,*/*;q=0.8',
            'Accept-Language': 'en-US,en;q=0.9',
            'Accept-Encoding': 'gzip, deflate, br',
            'Connection': 'keep-alive',
            'Upgrade-Insecure-Requests': '1',
        }

        # Download the image with a timeout
        response = requests.get(url, headers=headers, timeout=30, stream=True)
        response.raise_for_status()

        # Check content type
        content_type = response.headers.get('content-type', '').lower()
        if not content_type.startswith('image/'):
            # Google Drive sometimes omits the image content-type; only reject
            # when the payload is clearly a small HTML page (login/permission
            # screen for a non-public file).
            if len(response.content) < 1000 and b'html' in response.content.lower():
                # BUGFIX: this is a client-side access problem, so raise a 400
                # directly instead of a ValueError that the generic handler
                # below turned into a 500.
                raise HTTPException(status_code=400, detail="URL returned HTML content, not an image. Please ensure the Google Drive file is publicly accessible.")

        # Open image from response content
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
        return image

    except HTTPException:
        # Let our deliberate 400 above propagate unchanged.
        raise
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to download image from URL {url}: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Failed to download image: {str(e)}. Please ensure the Google Drive link is publicly accessible.")
    except UnidentifiedImageError:
        logger.error(f"Invalid image file from URL {url}")
        raise HTTPException(status_code=400, detail="URL does not contain a valid image file")
    except Exception as e:
        logger.error(f"Error processing image from URL {url}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
|
| 363 |
+
|
| 364 |
# Image classification endpoints
|
| 365 |
@app.post(
|
| 366 |
"/predict",
|
|
|
|
| 423 |
logger.info(f"Prediction for {file.filename}: {label} (Category: {category})")
|
| 424 |
return PredictionResponse(label=label, category=category, output_image=output_image)
|
| 425 |
|
| 426 |
+
@app.post(
    "/predict-from-url",
    response_model=PredictionResponse,
    summary="Classify a food image from URL",
    description="Provide an image URL (preferably Google Drive) to classify it into one of 101 food categories and return its category."
)
async def predict_image_from_url(request: ImageUrlRequest):
    """Download the image at the given URL, classify it, and return the
    predicted label, its food category, and (if produced) a base64-encoded
    output image.

    Raises:
        HTTPException(400): missing/invalid URL or inaccessible image.
        HTTPException(500): download or prediction failure.
    """
    image_url = request.image_url.strip()
    logger.info(f"Received image URL for prediction: {image_url}")

    if not image_url:
        raise HTTPException(status_code=400, detail="Image URL is required")

    # Validate URL format.
    try:
        parsed_url = urlparse(image_url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise HTTPException(status_code=400, detail="Invalid URL format")
    except HTTPException:
        # BUGFIX: don't let the generic handler below swallow and re-raise
        # our own deliberate 400.
        raise
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid URL format")

    # Convert Google Drive URL to direct download URL if needed
    direct_url = convert_google_drive_url(image_url)
    logger.info(f"Converted URL: {direct_url}")

    # Download and process image
    try:
        image = download_image_from_url(direct_url)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error downloading image: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to process image from URL")

    # Predict using the downloaded image
    try:
        # mkstemp returns an open OS-level fd; close it immediately and work
        # with the path so image.save() is the only writer of the file
        # (previously the fd stayed open across the save/predict calls).
        fd, temp_file_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        try:
            image.save(temp_file_path)
            label, output_image_path = classifier.predict(temp_file_path)
        finally:
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)  # Clean up
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")

    # Determine category
    category = next((cat for cat, foods in food_categories.items() if label in foods), "Uncategorized")

    # Encode output image as base64 if available
    output_image = None
    if output_image_path and os.path.exists(output_image_path):
        try:
            with open(output_image_path, "rb") as f:
                output_image = base64.b64encode(f.read()).decode("utf-8")
        except Exception as e:
            logger.warning(f"Failed to encode output image: {str(e)}")

    logger.info(f"Prediction for URL {image_url}: {label} (Category: {category})")
    return PredictionResponse(label=label, category=category, output_image=output_image)
|
| 489 |
+
|
| 490 |
# E-commerce assistant endpoints
|
| 491 |
def execute_function_call(raw_response: str) -> str:
|
| 492 |
think_end = raw_response.find('</think>')
|
|
|
|
| 624 |
"message": "Welcome to the EcoHarvest Combined API",
|
| 625 |
"endpoints": {
|
| 626 |
"image_classification": "/predict",
|
| 627 |
+
"image_classification_from_url": "/predict-from-url",
|
| 628 |
"ecommerce_assistant": "/ask",
|
| 629 |
"food_categories": "/food/categories",
|
| 630 |
"docs": "/docs"
|