Adding project files
Browse files- .gitignore +1 -0
- .idea/.gitignore +8 -0
- .idea/Motager-AI-Product-Helper.iml +8 -0
- .idea/inspectionProfiles/Project_Default.xml +16 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- API_main.py +228 -0
- Color_extraction.py +53 -0
- Generate_caption.py +121 -0
- Generate_productName_description.py +23 -0
- Generating_prompt.py +63 -0
- __pycache__/API_main.cpython-311.pyc +0 -0
- __pycache__/Color_extraction.cpython-311.pyc +0 -0
- __pycache__/Generate_caption.cpython-311.pyc +0 -0
- __pycache__/Generate_productName_description.cpython-311.pyc +0 -0
- __pycache__/Generating_prompt.cpython-311.pyc +0 -0
- requirements.txt +0 -0
- runtime.txt +1 -0
- test.py +25 -0
- train.py +172 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
.env
|
.idea/.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
| 4 |
+
# Editor-based HTTP Client requests
|
| 5 |
+
/httpRequests/
|
| 6 |
+
# Datasource local storage ignored files
|
| 7 |
+
/dataSources/
|
| 8 |
+
/dataSources.local.xml
|
.idea/Motager-AI-Product-Helper.iml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$" />
|
| 5 |
+
<orderEntry type="inheritedJdk" />
|
| 6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 7 |
+
</component>
|
| 8 |
+
</module>
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<profile version="1.0">
|
| 3 |
+
<option name="myName" value="Project Default" />
|
| 4 |
+
<inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
| 5 |
+
<option name="ignoredIdentifiers">
|
| 6 |
+
<list>
|
| 7 |
+
<option value="bayesnet.model.Node" />
|
| 8 |
+
<option value="bayesnet.model.DiscreteDistribution" />
|
| 9 |
+
<option value="bayesnet.model.ConditionalProbabilityTable" />
|
| 10 |
+
<option value="bayesnet.model.BayesianNetwork" />
|
| 11 |
+
<option value="bluetooth" />
|
| 12 |
+
</list>
|
| 13 |
+
</option>
|
| 14 |
+
</inspection_tool>
|
| 15 |
+
</profile>
|
| 16 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/Motager-AI-Product-Helper.iml" filepath="$PROJECT_DIR$/.idea/Motager-AI-Product-Helper.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
.idea/vcs.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="VcsDirectoryMappings">
|
| 4 |
+
<mapping directory="" vcs="Git" />
|
| 5 |
+
</component>
|
| 6 |
+
</project>
|
API_main.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException, Request
|
| 2 |
+
from fastapi.responses import JSONResponse
|
| 3 |
+
from fastapi.exceptions import RequestValidationError
|
| 4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
from typing import List
|
| 7 |
+
import asyncio
|
| 8 |
+
import os
|
| 9 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
from Generate_caption import load_model_from_path, tokenizer_load
|
| 12 |
+
from Color_extraction import extract_colors
|
| 13 |
+
from Generate_productName_description import generate_product_name, generate_description, clean_response
|
| 14 |
+
from huggingface_hub import hf_hub_download
|
| 15 |
+
import tempfile
|
| 16 |
+
|
| 17 |
+
app = FastAPI()

# CORS: only the local React dev server may call this API with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load environment variables
load_dotenv()
API_KEY = os.getenv("API_KEY")
if not API_KEY:
    raise ValueError("API_KEY not set. Please configure your .env file or system environment.")

# Globals populated asynchronously by load_models(); None until loading finishes.
vgg16_model = None
fifth_version_model = None
tokenizer = None
# Shared pool used to run the blocking ML helpers off the event loop.
executor = ThreadPoolExecutor(max_workers=4)

# Ensure ONNX model path is set
# NOTE(review): XDG_CACHE_HOME is conventionally a cache *directory*; pointing
# it at the u2net.onnx file path looks wrong — confirm rembg resolves this.
os.environ["XDG_CACHE_HOME"] = "models/u2net.onnx"
|
| 42 |
+
|
| 43 |
+
async def download_model_from_hf(repo_id: str, filename: str) -> str:
    """Download one model artifact from the Hugging Face Hub.

    Args:
        repo_id: Hub repository id, e.g. "user/repo".
        filename: File inside the repository to fetch.

    Returns:
        Local filesystem path of the downloaded file.

    Raises:
        Re-raises whatever hf_hub_download raises on network/auth failure.
    """
    try:
        # Keep every artifact under one temp directory shared across calls.
        model_dir = os.path.join(tempfile.gettempdir(), "hf_models")
        os.makedirs(model_dir, exist_ok=True)

        # hf_hub_download is a blocking call; run it on a worker thread so the
        # three downloads gathered in load_models() actually overlap instead of
        # serializing on the event loop.
        # NOTE(review): force_download=True re-fetches on every startup even
        # when the file is already present — confirm this is intentional.
        model_path = await asyncio.to_thread(
            hf_hub_download,
            repo_id=repo_id,
            filename=filename,
            cache_dir=model_dir,
            local_dir=model_dir,
            force_download=True,
        )
        # Name the file in the log so concurrent downloads are distinguishable.
        print(f"Downloaded {filename} to {model_path}")
        return model_path
    except Exception as e:
        print(f"Error downloading {filename}: {str(e)}")
        raise
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
async def load_models():
    """Download the three model artifacts, then deserialize them off-thread.

    Populates the module globals vgg16_model, fifth_version_model and
    tokenizer; skipped entirely when all three are already set (truthy).
    """
    global vgg16_model, fifth_version_model, tokenizer
    if not all([vgg16_model, fifth_version_model, tokenizer]):
        print("Downloading and loading models from Hugging Face Hub...")

        try:
            # Download models in parallel
            vgg16_path, model_path, tokenizer_path = await asyncio.gather(
                download_model_from_hf("abdallah-03/AI_product_helper_models", "vgg16_feature_extractor.keras"),
                download_model_from_hf("abdallah-03/AI_product_helper_models", "fifth_version_model.keras"),
                download_model_from_hf("abdallah-03/AI_product_helper_models", "tokenizer.pkl")
            )

            # Deserialize on worker threads so the event loop stays responsive.
            vgg16_task = asyncio.to_thread(load_model_from_path, vgg16_path)
            fifth_version_task = asyncio.to_thread(load_model_from_path, model_path)
            tokenizer_task = asyncio.to_thread(tokenizer_load, tokenizer_path)

            vgg16_model, fifth_version_model, tokenizer = await asyncio.gather(
                vgg16_task, fifth_version_task, tokenizer_task
            )
            print("Models loaded successfully!")

        except Exception as e:
            print(f"Error loading models: {str(e)}")
            raise
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@app.on_event("startup")
async def startup_event():
    """Start model loading in the background so the API serves immediately.

    Readiness is reported by the /status/ endpoint.
    NOTE(review): on_event is deprecated in recent FastAPI in favor of
    lifespan handlers — consider migrating.
    """
    asyncio.create_task(load_models())
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# Pydantic Models
|
| 98 |
+
# Pydantic request schemas.

class ImagePathsRequest(BaseModel):
    # Local paths or http(s) URLs of product images.
    image_paths: List[str]


class GenerateProductRequest(ImagePathsRequest):
    # Brand to embed in the title; "none"/"Generic" are treated as absent
    # by the prompt builder downstream.
    Brand_name: str


class GenerateDescriptionRequest(BaseModel):
    product_name: str


class AIproducthelper(ImagePathsRequest):
    # Same shape as GenerateProductRequest; used by the combined endpoint.
    Brand_name: str
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# Exception Handlers
|
| 115 |
+
# Exception Handlers

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    # Last-resort handler: normalize any uncaught error into the common
    # {"success": False, ...} envelope.
    return JSONResponse(
        status_code=500,
        content={"success": False, "message": "Internal Server Error", "error": repr(exc)},
    )


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    # Preserve the status code chosen by the raising endpoint.
    return JSONResponse(
        status_code=exc.status_code,
        content={"success": False, "message": exc.detail},
    )


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    # Body/validation failures -> 422 with the structured pydantic error list.
    return JSONResponse(
        status_code=422,
        content={"success": False, "message": "Validation Error", "errors": exc.errors()},
    )
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# Endpoints
|
| 140 |
+
@app.get("/")
async def read_root():
    """Liveness greeting; available even while models are still loading."""
    return {"message": "Hello from our API, models are loading in the background!"}
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@app.get("/status/")
async def check_status():
    """Report whether the background model loading has finished.

    Returns the same JSON shape in both states; "success" is True only when
    all three artifacts are loaded. The original duplicated the whole
    response dict in each branch and mixed truthiness with `is not None`;
    this builds the flags once and derives everything from them.
    """
    models_loaded = {
        "vgg16": vgg16_model is not None,
        "fifth_version": fifth_version_model is not None,
        "tokenizer": tokenizer is not None,
    }
    ready = all(models_loaded.values())
    return {
        "success": ready,
        "message": "Models are ready!" if ready else "Models are still loading...",
        "models_loaded": models_loaded,
    }
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
@app.post("/extract-colors/")
async def extract_colors_endpoint(request: ImagePathsRequest):
    """Return the dominant color of each submitted image."""
    if not request.image_paths:
        raise HTTPException(status_code=400, detail="Image list cannot be empty.")

    try:
        # Offload the CPU-bound color extraction to the shared thread pool.
        colors = await asyncio.get_event_loop().run_in_executor(executor, extract_colors, request.image_paths)
        return {"success": True, "colors": colors}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error extracting colors: {repr(exc)}")
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
@app.post("/generate-product-name/")
async def generate_product_name_endpoint(request: GenerateProductRequest):
    """Caption the images and ask the LLM for a product title."""
    if not request.image_paths:
        raise HTTPException(status_code=400, detail="Image list cannot be empty.")

    try:
        # Blocking model inference + LLM call run on the shared thread pool.
        product_name = await asyncio.get_event_loop().run_in_executor(
            executor, generate_product_name, request.image_paths, request.Brand_name,
            vgg16_model, fifth_version_model, tokenizer, API_KEY
        )
        return {"success": True, "product_name": product_name}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error generating product name: {repr(exc)}")
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@app.post("/generate-description/")
async def generate_description_endpoint(request: GenerateDescriptionRequest):
    """Generate a marketing description for an already-known product name."""
    try:
        description = await asyncio.get_event_loop().run_in_executor(
            executor, generate_description, API_KEY, request.product_name,
            vgg16_model, fifth_version_model, tokenizer
        )
        return {"success": True, "description": description}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error generating description: {repr(exc)}")
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
@app.post("/AI-product_help/")
async def ai_product_help_endpoint(request: AIproducthelper):
    """One-shot helper: title first, then a description based on that title."""
    if not request.image_paths:
        raise HTTPException(status_code=400, detail="Image list cannot be empty.")

    try:
        product_name = await asyncio.get_event_loop().run_in_executor(
            executor, generate_product_name, request.image_paths, request.Brand_name,
            vgg16_model, fifth_version_model, tokenizer, API_KEY
        )
        # Strip newlines/whitespace before feeding the title into the
        # description prompt.
        product_name = clean_response(product_name)

        description = await asyncio.get_event_loop().run_in_executor(
            executor, generate_description, API_KEY, product_name,
            vgg16_model, fifth_version_model, tokenizer
        )
        description = clean_response(description)

        return {"success": True, "product_name": product_name, "description": description}

    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error in AI product helper: {repr(exc)}")
|
Color_extraction.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
from rembg import remove
|
| 3 |
+
import numpy as np
|
| 4 |
+
import requests
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from sklearn.cluster import KMeans
|
| 7 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def download_image(image_url):
    """Fetch *image_url* over HTTP and return it as an RGBA PIL image.

    Raises:
        ValueError: on any network or HTTP-status failure.
    """
    try:
        resp = requests.get(image_url, stream=True, timeout=5)
        resp.raise_for_status()
        return Image.open(BytesIO(resp.content)).convert("RGBA")
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Error downloading image: {e}")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_image(image_path_or_url):
    """Load an RGBA image from either a local path or an http(s) URL."""
    if image_path_or_url.startswith("http"):
        return download_image(image_path_or_url)
    return Image.open(image_path_or_url).convert("RGBA")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def process_image(image):
    """Remove the background and return (cutout_image, foreground_mask).

    The mask is True where alpha > 0; if rembg returns a non-RGBA image,
    every pixel is treated as foreground.
    """
    output_image = remove(image)
    # PIL .size is (width, height); reverse it for a (rows, cols) numpy mask.
    mask = np.array(output_image)[:, :, 3] > 0 if output_image.mode == 'RGBA' else np.ones(output_image.size[::-1], dtype=bool)
    return output_image, mask
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def extract_dominant_colors(image, mask, color_count=2):
    """K-means cluster the foreground pixels into *color_count* hex colors.

    Returns a list of '#rrggbb' strings, or None when the mask selects no
    pixels (callers must handle the None case).
    """
    img_array = np.array(image)
    # Drop the alpha channel when present; keep only masked foreground pixels.
    product_pixels = img_array[mask][:, :3] if img_array.shape[-1] == 4 else img_array[mask]

    if len(product_pixels) == 0:
        return None  # No valid pixels found

    kmeans = KMeans(n_clusters=color_count, random_state=42, n_init="auto")  # Auto-tuned for efficiency
    kmeans.fit(product_pixels)
    return ['#{:02x}{:02x}{:02x}'.format(*map(int, color)) for color in kmeans.cluster_centers_]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def process_single_image(image_path_or_url, color_count):
    """Return the first dominant color (hex string) of one image, or None.

    Any load/processing failure is logged and mapped to None so batch
    extraction keeps going.
    """
    try:
        image = load_image(image_path_or_url)
        processed_image, mask = process_image(image)
        colors = extract_dominant_colors(processed_image, mask, color_count)
        # extract_dominant_colors returns None for an empty foreground mask;
        # the original indexed [0] unconditionally, turning that case into a
        # TypeError that was reported as a generic processing error.
        if not colors:
            return None
        return colors[0]  # first dominant color
    except Exception as e:
        print(f"Error processing image {image_path_or_url}: {e}")
        return None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def extract_colors(images_list, color_count=2):
    """Extract the first dominant color of every image, preserving order.

    Images are processed concurrently; failed images yield None entries.
    """
    with ThreadPoolExecutor() as pool:
        results = pool.map(
            lambda image: process_single_image(image, color_count),
            images_list,
        )
        return list(results)
|
Generate_caption.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
import requests
|
| 4 |
+
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
|
| 5 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
| 6 |
+
import numpy as np
|
| 7 |
+
from keras.src.utils import pad_sequences
|
| 8 |
+
from matplotlib import pyplot as plt
|
| 9 |
+
from keras.models import load_model
|
| 10 |
+
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
|
| 11 |
+
from tensorflow.keras.preprocessing.image import load_img, img_to_array
|
| 12 |
+
from tensorflow.keras.preprocessing.text import Tokenizer
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
def load_model_from_path(model_path):
    """Load a Keras model from disk.

    Returns the model, or None when the file is missing or fails to load
    (both cases are printed, not raised).
    """
    model_link = os.path.abspath(model_path)
    if os.path.exists(model_link):
        try:
            model = load_model(model_link)
            print(f"Model from {model_link} loaded successfully!")
            return model
        except Exception as e:
            print(f"Error loading model from {model_link}: {e}")
    else:
        print(f"File not found: {model_link}")
    # Reached on both the missing-file and load-failure paths.
    return None
|
| 27 |
+
|
| 28 |
+
def tokenizer_load(path):
    """Deserialize and return the pickled tokenizer stored at *path*.

    NOTE(review): pickle.load executes arbitrary code from the file — only
    use on trusted, locally produced artifacts.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)
|
| 32 |
+
|
| 33 |
+
def download_image(url, save_path):
    """Stream *url* to *save_path*; return the path, or None on any failure."""
    try:
        resp = requests.get(url, stream=True, timeout=10)
        resp.raise_for_status()  # 4xx/5xx -> exception
        with open(save_path, 'wb') as out:
            for chunk in resp.iter_content(1024):
                out.write(chunk)
        return save_path
    except Exception as e:
        print(f"Error downloading image {url}: {e}")
        return None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def extract_image_features_one(model, img_path):
    """Run the VGG16 feature extractor on one image (local path or URL).

    Returns the feature array, or None on download/IO/prediction failure.
    """
    # Set only when img_path is a URL; checked in `finally`. The original
    # left temp_path unassigned for local paths, so the `finally` block
    # raised NameError on every non-URL call.
    temp_path = None
    try:
        if img_path.startswith("http"):
            # NOTE(review): fixed filename is race-prone if called from
            # multiple threads — consider tempfile.NamedTemporaryFile.
            temp_path = "temp_image.jpg"
            img_path = download_image(img_path, temp_path)
            if img_path is None:
                return None

        if not os.path.exists(img_path):
            print(f"Error: Image path does not exist - {img_path}")
            return None

        # VGG16 expects 224x224 inputs with its own channel preprocessing.
        image = load_img(img_path, target_size=(224, 224))
        img_array = img_to_array(image)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = preprocess_input(img_array)

        feature = model.predict(img_array, verbose=0)

        if feature is None:
            print(f"Error: Model returned None for image - {img_path}")

        return feature
    except Exception as e:
        print(f"Exception in feature extraction: {e}")
        return None
    finally:
        # Clean up the temporary download, if any.
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def idx_to_word(integer, tokenizer):
    """Reverse-lookup the word mapped to *integer* in the tokenizer.

    Returns None when no word has that index.
    """
    return next(
        (word for word, index in tokenizer.word_index.items() if index == integer),
        None,
    )
|
| 82 |
+
|
| 83 |
+
def extract_captions(mapping):
    """Flatten the image-id -> captions mapping into one list of captions."""
    return [caption for captions in mapping.values() for caption in captions]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def prepare_tokenizer(captions_list):
    """Fit a Keras Tokenizer on the captions.

    Returns (tokenizer, vocab_size); vocab_size includes the +1 reserved
    for the padding index 0.
    """
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(captions_list)
    vocab_size = len(tokenizer.word_index) + 1
    return tokenizer, vocab_size
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def calculate_max_length(captions_list):
    """Return the word count of the longest caption.

    Returns 0 for an empty list; the original raised ValueError there
    because max() had no default.
    """
    return max((len(caption.split()) for caption in captions_list), default=0)
|
| 99 |
+
|
| 100 |
+
def predict_caption(model, image, tokenizer, max_length):
    """Greedy-decode a caption from precomputed image features.

    Starts at 'startseq' and appends the argmax word each step, stopping on
    'endseq', an index with no word, or after max_length steps. Returns the
    full text including the 'startseq' prefix (callers strip it).
    """
    in_text = 'startseq'
    for i in range(max_length):
        # Re-encode the partial caption and pad to the model's input length.
        sequence = tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length, padding='post')
        yhat = model.predict([image, sequence], verbose=0)
        yhat = np.argmax(yhat)
        word = idx_to_word(yhat, tokenizer)
        if word is None:
            break
        in_text += " " + word
        if word == 'endseq':
            break
    return in_text
|
| 114 |
+
|
| 115 |
+
def generate_caption(image_path, vgg16_model, model, tokenizer):
    """Generate a caption for one image.

    Returns the decoded caption text, or None when feature extraction
    fails. The original only printed on failure and then passed None into
    predict_caption, which crashes inside model.predict.
    """
    features_image = extract_image_features_one(vgg16_model, image_path)
    if features_image is None:
        print("Error: No features extracted from the image.")
        return None
    # 18 is the max caption length the captioning model was trained with —
    # TODO confirm against the training configuration.
    y_pred = predict_caption(model, features_image, tokenizer, 18)
    return y_pred
|
| 121 |
+
|
Generate_productName_description.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from Generating_prompt import generate_product_name_prompt , generate_description_prompt
|
| 2 |
+
import google.generativeai as genai
|
| 3 |
+
|
| 4 |
+
def clean_response(text: str) -> str:
    """Replace newlines with spaces and trim surrounding whitespace."""
    without_newlines = text.replace("\n", " ")
    return without_newlines.strip()
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def generate_product_name(image_path_list, Brand_name, vgg16_model, model, tokenizer, api_key):
    """Caption the images, then ask Gemini for a product title string."""
    prompt = generate_product_name_prompt(image_path_list, Brand_name, vgg16_model, model, tokenizer)
    genai.configure(api_key=api_key)
    # Fresh local name for the LLM client instead of rebinding the
    # captioning `model` parameter (already consumed above).
    llm = genai.GenerativeModel("gemini-2.0-flash")
    response = llm.generate_content(prompt)

    return str(response.text) if response.text else ""
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def generate_description(api_key, product_name, vgg16_model, model, tokenizer):
    """Ask Gemini for a formatted description of *product_name*."""
    prompt = generate_description_prompt(product_name, vgg16_model, model, tokenizer)
    genai.configure(api_key=api_key)
    # Fresh local name for the LLM client instead of rebinding `model`.
    llm = genai.GenerativeModel("gemini-2.0-flash")
    response = llm.generate_content(prompt)
    return str(response.text) if response.text else ""
|
| 23 |
+
|
Generating_prompt.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 3 |
+
from Generate_caption import generate_caption
|
| 4 |
+
from Color_extraction import extract_colors
|
| 5 |
+
from soupsieve.util import lower
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def generate_product_name_prompt(image_path, Brand_name, vgg16_model, model, tokenizer):
    """Build the LLM prompt that turns an image caption into a product title.

    Only the first image in *image_path* is captioned. Brand_name values of
    "none"/"generic" (case-insensitive) select the brandless prompt variant.
    """
    caption = generate_caption(image_path[0], vgg16_model, model, tokenizer)
    if (str(Brand_name).lower()) == "none" or str(Brand_name).lower() == "generic":
        statment = (f"Generate a product title name based on the caption {caption} as following information."
                    )
    else:
        statment = (f"Generate a product title name based on the caption {caption} and {Brand_name} as following information."
                    f"Replace the brand name with {Brand_name}. if it is not 'none' or 'None' or 'Generic'"
                    f"remember if {Brand_name} is 'none' or 'None' or 'Generic' do not depend on it and exclude it from product name"
                    f" Ensure the product title follows this format:<{Brand_name}> <Product Details>. "
                    )
    # Shared instruction tail appended to whichever variant was chosen above.
    prompt = (f"{statment}"
              f"reformat the {caption} to be professional product title like in the Amazon for website"
              f"Ensure that is only one product title"
              f"The product details should include features like product type and it must be somthing popular, series name, purpose "
              f"and any relevant specifics"
              f"Do NOT use escape characters or newline (\n)."
              f"excluding those words (startseq) and (endseq) removing any extra spaces."
              f"excluding any color and brand name from the product title without any(:) and (,)."
              f"example: 'Adidas T-Shirts Round Neck Cotton Full Sleeve'"
              f"do not say specific model number for the product title."

              f"example if brand name is Apple and caption is smartphone provide that it is iphone but do not provide it's model number as (15 pro max) "
              f"examples: iphone [no] pro max , Samsung S[no] ultra , Samsung A[no] get the model but not be very specific"
              f"do not generate none or generic in product name"

              )
    return prompt
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def generate_description_prompt(product_name, vgg16_model, model, tokenizer):
    """Build the LLM prompt for a 150-word product description.

    Only *product_name* is used; vgg16_model/model/tokenizer are accepted
    for signature parity with the title-prompt builder but are unused here.
    """
    prompt = (
        f'Generate a product description with the following sections: "About this item" and "Product description".'
        f'Based on this information:'
        f'Product Title: {product_name}'
        f'Important Requirements:'
        f'1. Limit the description to exactly 150 words.'
        f'2. Extract the brand name from the Product Title below and use it to reference the product within the description.'
        f'3. Follow the structure provided below for "About this item" and "Product description".'
        f'4. Ensure each line in the description contains two sentences, removing unnecessary spaces after periods (.).'
        f'5. Do NOT use escape characters or newline (\n).'

        f'Expected Output Format:'

        f'About this item Genuine leather construction for lasting durability.Multiple card slots and compartments for organization.Sleek and sophisticated design for a polished look.Compact size for easy carrying in pockets or bags.Secure closure to protect your valuables.Product Description.The polo leather wallet offers a premium feel and functionality. It\'s crafted from high-quality leather, ensuring both style and longevity.Its thoughtful design includes ample space for cards and cash. The compact size makes it ideal for everyday use.This polo leather wallet is a perfect blend of practicality and sophistication. It’s designed for the modern gentleman who appreciates quality.'
        f'Remember to:'
        f'remove any ("\n") in response'
        f'Each bullet in "About this item" should only have a maximum of 6 words.'
        f'Ensure each line in the description contains two sentences.'
        f'Remove and exclude extra spaces after (.).'
    )

    return prompt
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
__pycache__/API_main.cpython-311.pyc
ADDED
|
Binary file (7.28 kB). View file
|
|
|
__pycache__/Color_extraction.cpython-311.pyc
ADDED
|
Binary file (3.71 kB). View file
|
|
|
__pycache__/Generate_caption.cpython-311.pyc
ADDED
|
Binary file (5.72 kB). View file
|
|
|
__pycache__/Generate_productName_description.cpython-311.pyc
ADDED
|
Binary file (1.42 kB). View file
|
|
|
__pycache__/Generating_prompt.cpython-311.pyc
ADDED
|
Binary file (4.41 kB). View file
|
|
|
requirements.txt
ADDED
|
Binary file (490 Bytes). View file
|
|
|
runtime.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
python-3.11
|
test.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from Color_extraction import extract_colors
|
| 2 |
+
# from Generate_productName_description import generate_product_name, generate_description
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
import os
|
| 5 |
+
from Generate_caption import extract_image_features_one
|
| 6 |
+
# from Generate_productName_description import clean_response
|
| 7 |
+
# Load environment variables
|
| 8 |
+
# Scratch/smoke-test setup: load credentials the same way API_main does.
load_dotenv()
API_KEY = os.getenv("API_KEY")

if not API_KEY:
    raise ValueError("API_KEY not set. Please configure your .env file or system environment.")
|
| 13 |
+
|
| 14 |
+
# image_path_list = ['https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRcbfffkBR71xadfZ38APy1tclW2zQ77c6--g&s']
|
| 15 |
+
#
|
| 16 |
+
# product_name = generate_product_name(image_path_list,'Samsung',API_KEY)
|
| 17 |
+
# print(product_name)
|
| 18 |
+
# text = "None"
|
| 19 |
+
# print((text.lower()))
|
| 20 |
+
# color_list = extract_colors(image_path_list)
|
| 21 |
+
# print(color_list)
|
| 22 |
+
# description = generate_description(image_path_list,API_KEY,product_name,color_list)
|
| 23 |
+
# print(description)
|
| 24 |
+
# image = url_to_cv2_image("https://duuw10jl1n.ufs.sh/f/URa8oGmtpSmeY9aosOAeRgyf9hO1udBMVQv2tTG7YlCD8XLi")
|
| 25 |
+
# print(image)
|
train.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import random
import re

import numpy as np
from tqdm import tqdm

# NOTE(review): Generate_caption.load_model is deliberately imported *before*
# tensorflow.keras.models so that Keras's load_model (imported later) shadows
# it — this preserves the original import order; confirm the shadowing is
# intentional.
from Generate_caption import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.utils import pad_sequences, to_categorical
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, Embedding, GRU, add, LayerNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def extract_image_features(model, image_folder):
|
| 17 |
+
features = {}
|
| 18 |
+
directory = os.path( image_folder)
|
| 19 |
+
for item in tqdm(os.listdir(directory), desc="Extracting Features"):
|
| 20 |
+
item_path = os.path.join(directory, item)
|
| 21 |
+
if os.path.isfile(item_path):
|
| 22 |
+
try:
|
| 23 |
+
image = load_img(item_path, target_size=(224, 224))
|
| 24 |
+
img_array = img_to_array(image)
|
| 25 |
+
img_array = img_array.reshape((1, img_array.shape[0], img_array.shape[1], img_array.shape[2]))
|
| 26 |
+
img_array = preprocess_input(img_array)
|
| 27 |
+
feature = model.predict(img_array, verbose=0)
|
| 28 |
+
image_id = item.split('.')[0]
|
| 29 |
+
features[image_id] = feature
|
| 30 |
+
except Exception as e:
|
| 31 |
+
print(f"Error processing image {item_path}: {e}")
|
| 32 |
+
return features
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def read_captions_file(file_path):
|
| 36 |
+
try:
|
| 37 |
+
with open(file_path, 'r') as file:
|
| 38 |
+
next(file)
|
| 39 |
+
captions = file.read()
|
| 40 |
+
return captions
|
| 41 |
+
except Exception as e:
|
| 42 |
+
raise RuntimeError(f"Error reading the file: {e}")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def create_image_caption_mapping(captions):
|
| 46 |
+
mapping = {}
|
| 47 |
+
for line in tqdm(captions.split('\n'), desc="Processing Captions"):
|
| 48 |
+
tokens = line.split(',')
|
| 49 |
+
if len(tokens) < 2:
|
| 50 |
+
continue
|
| 51 |
+
image_id, caption = tokens[0], tokens[1:]
|
| 52 |
+
caption = " ".join(caption)
|
| 53 |
+
if image_id not in mapping:
|
| 54 |
+
mapping[image_id] = []
|
| 55 |
+
mapping[image_id].append(caption)
|
| 56 |
+
return mapping
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def preprocess_text(mapping):
|
| 60 |
+
for key, captions in mapping.items():
|
| 61 |
+
for i in range(len(captions)):
|
| 62 |
+
caption = captions[i].lower()
|
| 63 |
+
caption = caption.replace('[^A-Za-z]', ' ').replace('\s+', ' ')
|
| 64 |
+
caption = 'startseq ' + " ".join([word for word in caption.split() if len(word) > 1]) + ' endseq'
|
| 65 |
+
captions[i] = caption
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def extract_captions(mapping):
|
| 69 |
+
captions_list = []
|
| 70 |
+
for key in mapping:
|
| 71 |
+
captions_list.extend(mapping[key])
|
| 72 |
+
return captions_list
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def prepare_tokenizer(captions_list):
|
| 76 |
+
tokenizer = Tokenizer()
|
| 77 |
+
tokenizer.fit_on_texts(captions_list)
|
| 78 |
+
vocab_size = len(tokenizer.word_index) + 1
|
| 79 |
+
return tokenizer, vocab_size
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def calculate_max_length(captions_list):
|
| 83 |
+
return max(len(caption.split()) for caption in captions_list)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def split(image_ids, train_ratio, val_ratio=None):
|
| 87 |
+
random.shuffle(image_ids)
|
| 88 |
+
total = len(image_ids)
|
| 89 |
+
train_split = int(total * train_ratio)
|
| 90 |
+
val_split = int(total * (train_ratio + val_ratio)) if val_ratio else train_split
|
| 91 |
+
train_ids = image_ids[:train_split]
|
| 92 |
+
val_ids = image_ids[train_split:val_split] if val_ratio else []
|
| 93 |
+
test_ids = image_ids[val_split:]
|
| 94 |
+
return train_ids, val_ids, test_ids
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def data_generator(data_keys, mapping, features, tokenizer, max_length, vocab_size, batch_size):
|
| 98 |
+
X1, X2, y = [], [], []
|
| 99 |
+
n = 0
|
| 100 |
+
while True:
|
| 101 |
+
for key in data_keys:
|
| 102 |
+
n += 1
|
| 103 |
+
captions = mapping[key]
|
| 104 |
+
for caption in captions:
|
| 105 |
+
seq = tokenizer.texts_to_sequences([caption])[0]
|
| 106 |
+
for i in range(1, len(seq)):
|
| 107 |
+
in_seq, out_seq = seq[:i], seq[i]
|
| 108 |
+
in_seq = pad_sequences([in_seq], maxlen=max_length, padding='post')[0]
|
| 109 |
+
out_seq = to_categorical([out_seq], num_classes=vocab_size)[0]
|
| 110 |
+
X1.append(features[key][0])
|
| 111 |
+
X2.append(in_seq)
|
| 112 |
+
y.append(out_seq)
|
| 113 |
+
if n == batch_size:
|
| 114 |
+
yield {"image": np.array(X1), "text": np.array(X2)}, np.array(y)
|
| 115 |
+
X1, X2, y = [], [], []
|
| 116 |
+
n = 0
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def build_model(vocab_size, max_length):
|
| 120 |
+
inputs1 = Input(shape=(4096,), name="image")
|
| 121 |
+
fe1 = Dropout(0.4)(inputs1)
|
| 122 |
+
fe2 = Dense(256, activation='relu')(fe1)
|
| 123 |
+
fe3 = BatchNormalization()(fe2)
|
| 124 |
+
|
| 125 |
+
inputs2 = Input(shape=(max_length,), name="text")
|
| 126 |
+
se1 = Embedding(vocab_size, 256, mask_zero=True)(inputs2)
|
| 127 |
+
se2 = Dropout(0.4)(se1)
|
| 128 |
+
se3 = GRU(256, recurrent_dropout=0.3, return_sequences=False)(se2)
|
| 129 |
+
|
| 130 |
+
decoder1 = add([fe3, se3])
|
| 131 |
+
decoder2 = LayerNormalization()(decoder1)
|
| 132 |
+
decoder3 = Dense(512, activation='relu')(decoder2)
|
| 133 |
+
decoder4 = Dropout(0.3)(decoder3)
|
| 134 |
+
outputs = Dense(vocab_size, activation='softmax')(decoder4)
|
| 135 |
+
|
| 136 |
+
model = Model(inputs=[inputs1, inputs2], outputs=outputs)
|
| 137 |
+
optimizer = Adam(learning_rate=0.001)
|
| 138 |
+
model.compile(loss='categorical_crossentropy', optimizer=optimizer)
|
| 139 |
+
|
| 140 |
+
return model
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def load_existing_or_new_model(vocab_size, max_length, model_path="seven_version_model.keras"):
|
| 144 |
+
if os.path.exists(model_path):
|
| 145 |
+
print("Loading existing model...")
|
| 146 |
+
return load_model(model_path)
|
| 147 |
+
else:
|
| 148 |
+
print("No existing model found. Creating a new one...")
|
| 149 |
+
return build_model(vocab_size, max_length)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def continue_training(model, train, val, mapping, features, tokenizer, max_length, vocab_size, batch_size, epochs):
|
| 153 |
+
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True, verbose=1)
|
| 154 |
+
lr_scheduler = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-6, verbose=1)
|
| 155 |
+
|
| 156 |
+
steps = len(train) // batch_size
|
| 157 |
+
|
| 158 |
+
for i in range(epochs):
|
| 159 |
+
print(f"Epoch {i + 1}/{epochs}")
|
| 160 |
+
generator = data_generator(train, mapping, features, tokenizer, max_length, vocab_size, batch_size)
|
| 161 |
+
validation_generator = data_generator(val, mapping, features, tokenizer, max_length, vocab_size, batch_size)
|
| 162 |
+
|
| 163 |
+
model.fit(generator, validation_data=validation_generator, epochs=1, steps_per_epoch=steps,
|
| 164 |
+
validation_steps=len(val) // batch_size, verbose=1, callbacks=[early_stopping, lr_scheduler])
|
| 165 |
+
|
| 166 |
+
model.save("seven_version_model.keras")
|
| 167 |
+
print("Updated model saved successfully.")
|
| 168 |
+
|
| 169 |
+
#
|
| 170 |
+
# model = load_existing_or_new_model(vocab_size, max_length)
|
| 171 |
+
# continue_training(model, train_ids, val_ids, mapping, features, tokenizer, max_length, vocab_size, batch_size=64,
|
| 172 |
+
# epochs=10)
|