Robaa commited on
Commit
2f62519
·
verified ·
1 Parent(s): e09d4fa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +228 -0
app.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Request
2
+ from fastapi.responses import JSONResponse
3
+ from fastapi.exceptions import RequestValidationError
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel
6
+ from typing import List
7
+ import asyncio
8
+ import os
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ from dotenv import load_dotenv
11
+ from Generate_caption import load_model_from_path, tokenizer_load
12
+ from Color_extraction import extract_colors
13
+ from Generate_productName_description import generate_product_name, generate_description, clean_response
14
+ from huggingface_hub import hf_hub_download
15
+ import tempfile
16
+
17
+ app = FastAPI()
18
+
19
+ # CORS Middleware
20
+ app.add_middleware(
21
+ CORSMiddleware,
22
+ allow_origins=["http://localhost:3000"],
23
+ allow_credentials=True,
24
+ allow_methods=["*"],
25
+ allow_headers=["*"],
26
+ )
27
+
28
+ # Load environment variables
29
+ load_dotenv()
30
+ API_KEY = os.getenv("API_KEY")
31
+ if not API_KEY:
32
+ raise ValueError("API_KEY not set. Please configure your .env file or system environment.")
33
+
34
+ # Global variables for models and ThreadPool
35
+ vgg16_model = None
36
+ fifth_version_model = None
37
+ tokenizer = None
38
+ executor = ThreadPoolExecutor(max_workers=4)
39
+
40
+ # Ensure ONNX model path is set
41
+ os.environ["XDG_CACHE_HOME"] = "models/u2net.onnx"
42
+
43
+ async def download_model_from_hf(repo_id: str, filename: str) -> str:
44
+ try:
45
+ # Create a temporary directory for model files
46
+ model_dir = os.path.join(tempfile.gettempdir(), "hf_models")
47
+ os.makedirs(model_dir, exist_ok=True)
48
+
49
+ # Download model
50
+ model_path = hf_hub_download(
51
+ repo_id=repo_id,
52
+ filename=filename,
53
+ cache_dir=model_dir,
54
+ local_dir=model_dir,
55
+ force_download=True
56
+ )
57
+ print(f"Downloaded {filename} to {model_path}")
58
+ return model_path
59
+ except Exception as e:
60
+ print(f"Error downloading {filename}: {str(e)}")
61
+ raise
62
+
63
+
64
+ async def load_models():
65
+ global vgg16_model, fifth_version_model, tokenizer
66
+ if not all([vgg16_model, fifth_version_model, tokenizer]):
67
+ print("Downloading and loading models from Hugging Face Hub...")
68
+
69
+ try:
70
+ # Download models in parallel
71
+ vgg16_path, model_path, tokenizer_path = await asyncio.gather(
72
+ download_model_from_hf("abdallah-03/AI_product_helper_models", "vgg16_feature_extractor.keras"),
73
+ download_model_from_hf("abdallah-03/AI_product_helper_models", "fifth_version_model.keras"),
74
+ download_model_from_hf("abdallah-03/AI_product_helper_models", "tokenizer.pkl")
75
+ )
76
+
77
+ # Load models using the downloaded paths
78
+ vgg16_task = asyncio.to_thread(load_model_from_path, vgg16_path)
79
+ fifth_version_task = asyncio.to_thread(load_model_from_path, model_path)
80
+ tokenizer_task = asyncio.to_thread(tokenizer_load, tokenizer_path)
81
+
82
+ vgg16_model, fifth_version_model, tokenizer = await asyncio.gather(
83
+ vgg16_task, fifth_version_task, tokenizer_task
84
+ )
85
+ print("Models loaded successfully!")
86
+
87
+ except Exception as e:
88
+ print(f"Error loading models: {str(e)}")
89
+ raise
90
+
91
+
92
+ @app.on_event("startup")
93
+ async def startup_event():
94
+ asyncio.create_task(load_models())
95
+
96
+
97
+ # Pydantic Models
98
+ class ImagePathsRequest(BaseModel):
99
+ image_paths: List[str]
100
+
101
+
102
+ class GenerateProductRequest(ImagePathsRequest):
103
+ Brand_name: str
104
+
105
+
106
+ class GenerateDescriptionRequest(BaseModel):
107
+ product_name: str
108
+
109
+
110
+ class AIproducthelper(ImagePathsRequest):
111
+ Brand_name: str
112
+
113
+
114
+ # Exception Handlers
115
+ @app.exception_handler(Exception)
116
+ async def global_exception_handler(request: Request, exc: Exception):
117
+ return JSONResponse(
118
+ status_code=500,
119
+ content={"success": False, "message": "Internal Server Error", "error": repr(exc)},
120
+ )
121
+
122
+
123
+ @app.exception_handler(HTTPException)
124
+ async def http_exception_handler(request: Request, exc: HTTPException):
125
+ return JSONResponse(
126
+ status_code=exc.status_code,
127
+ content={"success": False, "message": exc.detail},
128
+ )
129
+
130
+
131
+ @app.exception_handler(RequestValidationError)
132
+ async def validation_exception_handler(request: Request, exc: RequestValidationError):
133
+ return JSONResponse(
134
+ status_code=422,
135
+ content={"success": False, "message": "Validation Error", "errors": exc.errors()},
136
+ )
137
+
138
+
139
+ # Endpoints
140
+ @app.get("/")
141
+ async def read_root():
142
+ return {"message": "Hello from our API, models are loading in the background!"}
143
+
144
+
145
+ @app.get("/status/")
146
+ async def check_status():
147
+ if all([vgg16_model, fifth_version_model, tokenizer]):
148
+ return {
149
+ "success": True,
150
+ "message": "Models are ready!",
151
+ "models_loaded": {
152
+ "vgg16": vgg16_model is not None,
153
+ "fifth_version": fifth_version_model is not None,
154
+ "tokenizer": tokenizer is not None
155
+ }
156
+ }
157
+ return {
158
+ "success": False,
159
+ "message": "Models are still loading...",
160
+ "models_loaded": {
161
+ "vgg16": vgg16_model is not None,
162
+ "fifth_version": fifth_version_model is not None,
163
+ "tokenizer": tokenizer is not None
164
+ }
165
+ }
166
+
167
+
168
+ @app.post("/extract-colors/")
169
+ async def extract_colors_endpoint(request: ImagePathsRequest):
170
+ if not request.image_paths:
171
+ raise HTTPException(status_code=400, detail="Image list cannot be empty.")
172
+
173
+ try:
174
+ colors = await asyncio.get_event_loop().run_in_executor(executor, extract_colors, request.image_paths)
175
+ return {"success": True, "colors": colors}
176
+ except Exception as exc:
177
+ raise HTTPException(status_code=500, detail=f"Error extracting colors: {repr(exc)}")
178
+
179
+
180
+ @app.post("/generate-product-name/")
181
+ async def generate_product_name_endpoint(request: GenerateProductRequest):
182
+ if not request.image_paths:
183
+ raise HTTPException(status_code=400, detail="Image list cannot be empty.")
184
+
185
+ try:
186
+ product_name = await asyncio.get_event_loop().run_in_executor(
187
+ executor, generate_product_name, request.image_paths, request.Brand_name,
188
+ vgg16_model, fifth_version_model, tokenizer, API_KEY
189
+ )
190
+ return {"success": True, "product_name": product_name}
191
+ except Exception as exc:
192
+ raise HTTPException(status_code=500, detail=f"Error generating product name: {repr(exc)}")
193
+
194
+
195
+ @app.post("/generate-description/")
196
+ async def generate_description_endpoint(request: GenerateDescriptionRequest):
197
+ try:
198
+ description = await asyncio.get_event_loop().run_in_executor(
199
+ executor, generate_description, API_KEY, request.product_name,
200
+ vgg16_model, fifth_version_model, tokenizer
201
+ )
202
+ return {"success": True, "description": description}
203
+ except Exception as exc:
204
+ raise HTTPException(status_code=500, detail=f"Error generating description: {repr(exc)}")
205
+
206
+
207
+ @app.post("/AI-product_help/")
208
+ async def ai_product_help_endpoint(request: AIproducthelper):
209
+ if not request.image_paths:
210
+ raise HTTPException(status_code=400, detail="Image list cannot be empty.")
211
+
212
+ try:
213
+ product_name = await asyncio.get_event_loop().run_in_executor(
214
+ executor, generate_product_name, request.image_paths, request.Brand_name,
215
+ vgg16_model, fifth_version_model, tokenizer, API_KEY
216
+ )
217
+ product_name = clean_response(product_name)
218
+
219
+ description = await asyncio.get_event_loop().run_in_executor(
220
+ executor, generate_description, API_KEY, product_name,
221
+ vgg16_model, fifth_version_model, tokenizer
222
+ )
223
+ description = clean_response(description)
224
+
225
+ return {"success": True, "product_name": product_name, "description": description}
226
+
227
+ except Exception as exc:
228
+ raise HTTPException(status_code=500, detail=f"Error in AI product helper: {repr(exc)}")