|
|
from fastapi import FastAPI, HTTPException, Request |
|
|
from fastapi.responses import JSONResponse |
|
|
from fastapi.exceptions import RequestValidationError |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel |
|
|
from typing import List |
|
|
import asyncio |
|
|
import os |
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
from src.Generate_caption import load_model_from_path, tokenizer_load |
|
|
from src.Color_extraction import extract_colors |
|
|
from src.Generate_productName_description import generate_product_name, generate_description, clean_response |
|
|
from huggingface_hub import hf_hub_download |
|
|
import tempfile |
|
|
|
|
|
# The FastAPI application instance that every route and handler below registers on.
app = FastAPI()

# Credential passed to the product-name / description generators. It is read
# from the environment variable literally named "APIKey".
API_KEY = os.environ.get("APIKey")

# Fail fast at import time instead of on the first request. There is no point
# printing the value here: this branch only runs when it is None or empty.
if not API_KEY:
    raise ValueError(
        "API_KEY not set. Please set the 'APIKey' environment variable "
        "in your .env file or system environment."
    )
|
|
|
|
|
|
|
|
# Handles for the caption models and tokenizer. They stay None until
# load_models() finishes and assigns them.
vgg16_model = fifth_version_model = tokenizer = None

# Shared thread pool that the endpoints use for blocking work.
executor = ThreadPoolExecutor(4)

# Directory handed to hf_hub_download for both its cache and its local copies.
HF_CACHE_DIR = "/app/hf_models_cache"
os.makedirs(HF_CACHE_DIR, exist_ok=True)

# Point XDG_CACHE_HOME at a directory under /app (named for ONNX caches) and
# make sure it exists. This overrides any value already in the environment.
os.environ["XDG_CACHE_HOME"] = "/app/onnx_cache"
os.makedirs(os.environ["XDG_CACHE_HOME"], exist_ok=True)
|
|
|
|
|
async def download_model_from_hf(repo_id: str, filename: str) -> str:
    """Fetch ``filename`` from the Hugging Face Hub repository ``repo_id``.

    ``hf_hub_download`` is a blocking call, so it runs in a worker thread via
    ``asyncio.to_thread``. Calling it directly inside this coroutine would
    freeze the event loop for the whole download. It would also serialize the
    downloads that ``load_models`` starts together with ``asyncio.gather``.

    Args:
        repo_id: Hugging Face Hub repository id.
        filename: Name of the file to fetch from that repository.

    Returns:
        The local path returned by ``hf_hub_download``.

    Raises:
        Exception: whatever ``hf_hub_download`` raises is logged and re-raised.
    """
    try:
        model_path = await asyncio.to_thread(
            hf_hub_download,
            repo_id=repo_id,
            filename=filename,
            cache_dir=HF_CACHE_DIR,
            local_dir=HF_CACHE_DIR,
            # Reuse an already-present copy rather than forcing a re-download.
            force_download=False,
        )
    except Exception as e:
        print(f"Error downloading/finding {filename}: {e}")
        raise
    print(f"Using model {filename} from {model_path}")
    return model_path
|
|
|
|
|
|
|
|
async def load_models():
    """Download (if needed) and load the caption models and the tokenizer.

    Assigns the module-level ``vgg16_model``, ``fifth_version_model`` and
    ``tokenizer``. Returns immediately if all three are already set. The three
    downloads are awaited together. The three loader calls then run in worker
    threads and are awaited together as well.

    Raises:
        Exception: any download or load failure is printed and re-raised. The
        globals are not reassigned in that case.
    """
    global vgg16_model, fifth_version_model, tokenizer

    # Compare against None explicitly. The truthiness of a model or tokenizer
    # object is not a reliable "is loaded" signal (an object defining
    # __len__/__bool__ can be falsy while perfectly usable).
    if all(obj is not None for obj in (vgg16_model, fifth_version_model, tokenizer)):
        return

    print("Downloading and loading models from Hugging Face Hub...")
    repo_id = "abdallah-03/AI_product_helper_models"
    try:
        vgg16_path, model_path, tokenizer_path = await asyncio.gather(
            download_model_from_hf(repo_id, "vgg16_feature_extractor.keras"),
            download_model_from_hf(repo_id, "fifth_version_model.keras"),
            # NOTE(review): tokenizer.pkl is presumably unpickled by
            # tokenizer_load. Only load it from a repository you trust.
            download_model_from_hf(repo_id, "tokenizer.pkl"),
        )

        # The loaders are synchronous, so keep them off the event loop.
        vgg16_model, fifth_version_model, tokenizer = await asyncio.gather(
            asyncio.to_thread(load_model_from_path, vgg16_path),
            asyncio.to_thread(load_model_from_path, model_path),
            asyncio.to_thread(tokenizer_load, tokenizer_path),
        )
    except Exception as e:
        print(f"Error loading models: {e}")
        raise

    print("Models loaded successfully!")
|
|
|
|
|
|
|
|
# Strong references to tasks started at startup. The event loop keeps only
# weak references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_background_tasks = set()


@app.on_event("startup")
async def startup_event():
    """Kick off model loading in the background so startup is not blocked."""
    task = asyncio.create_task(load_models())
    _background_tasks.add(task)
    # Release the reference once the task is done so the set does not grow.
    task.add_done_callback(_background_tasks.discard)
|
|
|
|
|
|
|
|
|
|
|
class ImagePathsRequest(BaseModel):
    # Request body listing the image paths to process. It is also the base
    # class for the other image-based request models.
    image_paths: list[str]
|
|
|
|
|
|
|
|
class GenerateProductRequest(ImagePathsRequest):
    # Request body for /generate-product-name/. It adds the brand name to the
    # inherited image_paths.
    Brand_name: str
|
|
|
|
|
|
|
|
class GenerateDescriptionRequest(BaseModel):
    # Request body for /generate-description/: the product name to describe.
    product_name: str
|
|
|
|
|
|
|
|
class AIproducthelper(ImagePathsRequest):
    # Request body for /AI-product_help/: the inherited image_paths plus the
    # brand name used when generating the product name.
    Brand_name: str
|
|
|
|
|
|
|
|
|
|
|
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: report an unhandled exception as a 500 JSON body.

    NOTE(review): repr(exc) is sent to the client. Confirm that exposing
    internal error details is acceptable for this deployment.
    """
    body = {
        "success": False,
        "message": "Internal Server Error",
        "error": repr(exc),
    }
    return JSONResponse(content=body, status_code=500)
|
|
|
|
|
|
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an HTTPException as ``{"success": False, "message": detail}``.

    The response keeps the exception's status code. Any headers attached via
    ``HTTPException(..., headers=...)`` are forwarded too, as FastAPI's
    default handler does.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"success": False, "message": exc.detail},
        headers=getattr(exc, "headers", None),
    )
|
|
|
|
|
|
|
|
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Render request-validation failures as a 422 JSON payload.

    ``exc.errors()`` may contain values that are not JSON-serializable, such
    as an exception instance inside an error's ``ctx``. Passing those straight
    to JSONResponse would make the handler itself fail. ``jsonable_encoder``
    converts them first, as FastAPI's own default handler does.
    """
    # Local import keeps this handler self-contained. fastapi is already a
    # dependency of this module.
    from fastapi.encoders import jsonable_encoder

    return JSONResponse(
        status_code=422,
        content={
            "success": False,
            "message": "Validation Error",
            "errors": jsonable_encoder(exc.errors()),
        },
    )
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
def read_root():
    """Return a fixed greeting for the API root."""
    greeting = "Hello from our API, models are loading in the background!"
    return {"message": greeting}
|
|
|
|
|
|
|
|
@app.get("/status/")
async def check_status():
    """Report which models have finished loading.

    ``success`` is True only when all three handles are set. ``models_loaded``
    gives the per-model state either way.
    """
    # Explicit None checks: a loaded object's truthiness is not a reliable
    # "is loaded" signal.
    models_loaded = {
        "vgg16": vgg16_model is not None,
        "fifth_version": fifth_version_model is not None,
        "tokenizer": tokenizer is not None,
    }
    ready = all(models_loaded.values())
    return {
        "success": ready,
        "message": "Models are ready!" if ready else "Models are still loading...",
        "models_loaded": models_loaded,
    }
|
|
|
|
|
|
|
|
@app.post("/extract-colors/")
async def extract_colors_endpoint(request: ImagePathsRequest):
    """Extract colors from the given images on the shared thread pool.

    Raises:
        HTTPException: 400 if ``image_paths`` is empty. 500 (chained to the
        original error) if extraction fails.
    """
    if not request.image_paths:
        raise HTTPException(status_code=400, detail="Image list cannot be empty.")

    # get_running_loop() is the recommended accessor inside a coroutine.
    loop = asyncio.get_running_loop()
    try:
        colors = await loop.run_in_executor(executor, extract_colors, request.image_paths)
    except Exception as exc:
        raise HTTPException(
            status_code=500, detail=f"Error extracting colors: {repr(exc)}"
        ) from exc
    return {"success": True, "colors": colors}
|
|
|
|
|
|
|
|
@app.post("/generate-product-name/")
async def generate_product_name_endpoint(request: GenerateProductRequest):
    """Generate a product name from the images and brand name.

    Raises:
        HTTPException: 400 if ``image_paths`` is empty. 503 while the models
        are still loading in the background (instead of calling the generator
        with None models). 500 (chained) if generation fails.
    """
    if not request.image_paths:
        raise HTTPException(status_code=400, detail="Image list cannot be empty.")
    if any(m is None for m in (vgg16_model, fifth_version_model, tokenizer)):
        raise HTTPException(
            status_code=503,
            detail="Models are still loading, please try again shortly.",
        )

    loop = asyncio.get_running_loop()
    try:
        product_name = await loop.run_in_executor(
            executor, generate_product_name, request.image_paths, request.Brand_name,
            vgg16_model, fifth_version_model, tokenizer, API_KEY,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=500, detail=f"Error generating product name: {repr(exc)}"
        ) from exc
    return {"success": True, "product_name": product_name}
|
|
|
|
|
|
|
|
@app.post("/generate-description/")
async def generate_description_endpoint(request: GenerateDescriptionRequest):
    """Generate a description for ``request.product_name`` on the shared pool.

    The model handles are passed through as they are. Whether
    generate_description needs them loaded is not visible here, so there is
    no readiness check.

    Raises:
        HTTPException: 500 (chained to the original error) if generation fails.
    """
    loop = asyncio.get_running_loop()
    try:
        description = await loop.run_in_executor(
            executor, generate_description, API_KEY, request.product_name,
            vgg16_model, fifth_version_model, tokenizer,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=500, detail=f"Error generating description: {repr(exc)}"
        ) from exc
    return {"success": True, "description": description}
|
|
|
|
|
|
|
|
@app.post("/AI-product_help/")
async def ai_product_help_endpoint(request: AIproducthelper):
    """Generate a product name and then a description, both cleaned.

    The cleaned product name is what gets passed to the description generator.

    Raises:
        HTTPException: 400 if ``image_paths`` is empty. 503 while the models
        are still loading in the background. 500 (chained) if either
        generation step fails.
    """
    if not request.image_paths:
        raise HTTPException(status_code=400, detail="Image list cannot be empty.")
    if any(m is None for m in (vgg16_model, fifth_version_model, tokenizer)):
        raise HTTPException(
            status_code=503,
            detail="Models are still loading, please try again shortly.",
        )

    loop = asyncio.get_running_loop()
    try:
        raw_name = await loop.run_in_executor(
            executor, generate_product_name, request.image_paths, request.Brand_name,
            vgg16_model, fifth_version_model, tokenizer, API_KEY,
        )
        product_name = clean_response(raw_name)

        raw_description = await loop.run_in_executor(
            executor, generate_description, API_KEY, product_name,
            vgg16_model, fifth_version_model, tokenizer,
        )
        description = clean_response(raw_description)
    except Exception as exc:
        raise HTTPException(
            status_code=500, detail=f"Error in AI product helper: {repr(exc)}"
        ) from exc

    return {"success": True, "product_name": product_name, "description": description}