Robaa commited on
Commit
e09d4fa
·
verified ·
1 Parent(s): 7134ae0

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -228
app.py DELETED
@@ -1,228 +0,0 @@
1
- from fastapi import FastAPI, HTTPException, Request
2
- from fastapi.responses import JSONResponse
3
- from fastapi.exceptions import RequestValidationError
4
- from fastapi.middleware.cors import CORSMiddleware
5
- from pydantic import BaseModel
6
- from typing import List
7
- import asyncio
8
- import os
9
- from concurrent.futures import ThreadPoolExecutor
10
- from dotenv import load_dotenv
11
- from Generate_caption import load_model_from_path, tokenizer_load
12
- from Color_extraction import extract_colors
13
- from Generate_productName_description import generate_product_name, generate_description, clean_response
14
- from huggingface_hub import hf_hub_download
15
- import tempfile
16
-
17
- app = FastAPI()
18
-
19
- # CORS Middleware
20
- app.add_middleware(
21
- CORSMiddleware,
22
- allow_origins=["http://localhost:3000"],
23
- allow_credentials=True,
24
- allow_methods=["*"],
25
- allow_headers=["*"],
26
- )
27
-
28
- # Load environment variables
29
- load_dotenv()
30
- API_KEY = os.getenv("API_KEY")
31
- if not API_KEY:
32
- raise ValueError("API_KEY not set. Please configure your .env file or system environment.")
33
-
34
- # Global variables for models and ThreadPool
35
- vgg16_model = None
36
- fifth_version_model = None
37
- tokenizer = None
38
- executor = ThreadPoolExecutor(max_workers=4)
39
-
40
- # Ensure ONNX model path is set
41
- os.environ["XDG_CACHE_HOME"] = "models/u2net.onnx"
42
-
43
- async def download_model_from_hf(repo_id: str, filename: str) -> str:
44
- try:
45
- # Create a temporary directory for model files
46
- model_dir = os.path.join(tempfile.gettempdir(), "hf_models")
47
- os.makedirs(model_dir, exist_ok=True)
48
-
49
- # Download model
50
- model_path = hf_hub_download(
51
- repo_id=repo_id,
52
- filename=filename,
53
- cache_dir=model_dir,
54
- local_dir=model_dir,
55
- force_download=True
56
- )
57
- print(f"Downloaded {filename} to {model_path}")
58
- return model_path
59
- except Exception as e:
60
- print(f"Error downloading {filename}: {str(e)}")
61
- raise
62
-
63
-
64
- async def load_models():
65
- global vgg16_model, fifth_version_model, tokenizer
66
- if not all([vgg16_model, fifth_version_model, tokenizer]):
67
- print("Downloading and loading models from Hugging Face Hub...")
68
-
69
- try:
70
- # Download models in parallel
71
- vgg16_path, model_path, tokenizer_path = await asyncio.gather(
72
- download_model_from_hf("abdallah-03/AI_product_helper_models", "vgg16_feature_extractor.keras"),
73
- download_model_from_hf("abdallah-03/AI_product_helper_models", "fifth_version_model.keras"),
74
- download_model_from_hf("abdallah-03/AI_product_helper_models", "tokenizer.pkl")
75
- )
76
-
77
- # Load models using the downloaded paths
78
- vgg16_task = asyncio.to_thread(load_model_from_path, vgg16_path)
79
- fifth_version_task = asyncio.to_thread(load_model_from_path, model_path)
80
- tokenizer_task = asyncio.to_thread(tokenizer_load, tokenizer_path)
81
-
82
- vgg16_model, fifth_version_model, tokenizer = await asyncio.gather(
83
- vgg16_task, fifth_version_task, tokenizer_task
84
- )
85
- print("Models loaded successfully!")
86
-
87
- except Exception as e:
88
- print(f"Error loading models: {str(e)}")
89
- raise
90
-
91
-
92
- @app.on_event("startup")
93
- async def startup_event():
94
- asyncio.create_task(load_models())
95
-
96
-
97
- # Pydantic Models
98
- class ImagePathsRequest(BaseModel):
99
- image_paths: List[str]
100
-
101
-
102
- class GenerateProductRequest(ImagePathsRequest):
103
- Brand_name: str
104
-
105
-
106
- class GenerateDescriptionRequest(BaseModel):
107
- product_name: str
108
-
109
-
110
- class AIproducthelper(ImagePathsRequest):
111
- Brand_name: str
112
-
113
-
114
- # Exception Handlers
115
- @app.exception_handler(Exception)
116
- async def global_exception_handler(request: Request, exc: Exception):
117
- return JSONResponse(
118
- status_code=500,
119
- content={"success": False, "message": "Internal Server Error", "error": repr(exc)},
120
- )
121
-
122
-
123
- @app.exception_handler(HTTPException)
124
- async def http_exception_handler(request: Request, exc: HTTPException):
125
- return JSONResponse(
126
- status_code=exc.status_code,
127
- content={"success": False, "message": exc.detail},
128
- )
129
-
130
-
131
- @app.exception_handler(RequestValidationError)
132
- async def validation_exception_handler(request: Request, exc: RequestValidationError):
133
- return JSONResponse(
134
- status_code=422,
135
- content={"success": False, "message": "Validation Error", "errors": exc.errors()},
136
- )
137
-
138
-
139
- # Endpoints
140
- @app.get("/")
141
- async def read_root():
142
- return {"message": "Hello from our API, models are loading in the background!"}
143
-
144
-
145
- @app.get("/status/")
146
- async def check_status():
147
- if all([vgg16_model, fifth_version_model, tokenizer]):
148
- return {
149
- "success": True,
150
- "message": "Models are ready!",
151
- "models_loaded": {
152
- "vgg16": vgg16_model is not None,
153
- "fifth_version": fifth_version_model is not None,
154
- "tokenizer": tokenizer is not None
155
- }
156
- }
157
- return {
158
- "success": False,
159
- "message": "Models are still loading...",
160
- "models_loaded": {
161
- "vgg16": vgg16_model is not None,
162
- "fifth_version": fifth_version_model is not None,
163
- "tokenizer": tokenizer is not None
164
- }
165
- }
166
-
167
-
168
- @app.post("/extract-colors/")
169
- async def extract_colors_endpoint(request: ImagePathsRequest):
170
- if not request.image_paths:
171
- raise HTTPException(status_code=400, detail="Image list cannot be empty.")
172
-
173
- try:
174
- colors = await asyncio.get_event_loop().run_in_executor(executor, extract_colors, request.image_paths)
175
- return {"success": True, "colors": colors}
176
- except Exception as exc:
177
- raise HTTPException(status_code=500, detail=f"Error extracting colors: {repr(exc)}")
178
-
179
-
180
- @app.post("/generate-product-name/")
181
- async def generate_product_name_endpoint(request: GenerateProductRequest):
182
- if not request.image_paths:
183
- raise HTTPException(status_code=400, detail="Image list cannot be empty.")
184
-
185
- try:
186
- product_name = await asyncio.get_event_loop().run_in_executor(
187
- executor, generate_product_name, request.image_paths, request.Brand_name,
188
- vgg16_model, fifth_version_model, tokenizer, API_KEY
189
- )
190
- return {"success": True, "product_name": product_name}
191
- except Exception as exc:
192
- raise HTTPException(status_code=500, detail=f"Error generating product name: {repr(exc)}")
193
-
194
-
195
- @app.post("/generate-description/")
196
- async def generate_description_endpoint(request: GenerateDescriptionRequest):
197
- try:
198
- description = await asyncio.get_event_loop().run_in_executor(
199
- executor, generate_description, API_KEY, request.product_name,
200
- vgg16_model, fifth_version_model, tokenizer
201
- )
202
- return {"success": True, "description": description}
203
- except Exception as exc:
204
- raise HTTPException(status_code=500, detail=f"Error generating description: {repr(exc)}")
205
-
206
-
207
- @app.post("/AI-product_help/")
208
- async def ai_product_help_endpoint(request: AIproducthelper):
209
- if not request.image_paths:
210
- raise HTTPException(status_code=400, detail="Image list cannot be empty.")
211
-
212
- try:
213
- product_name = await asyncio.get_event_loop().run_in_executor(
214
- executor, generate_product_name, request.image_paths, request.Brand_name,
215
- vgg16_model, fifth_version_model, tokenizer, API_KEY
216
- )
217
- product_name = clean_response(product_name)
218
-
219
- description = await asyncio.get_event_loop().run_in_executor(
220
- executor, generate_description, API_KEY, product_name,
221
- vgg16_model, fifth_version_model, tokenizer
222
- )
223
- description = clean_response(description)
224
-
225
- return {"success": True, "product_name": product_name, "description": description}
226
-
227
- except Exception as exc:
228
- raise HTTPException(status_code=500, detail=f"Error in AI product helper: {repr(exc)}")