Spaces:
Build error
Build error
Initial commit
Browse files- .github/workflows/dockerhub.yaml +39 -0
- .gitignore +7 -0
- Dockerfile +21 -0
- app.py +224 -0
- requirements.txt +80 -0
- setup.py +17 -0
- src/components/__init__.py +0 -0
- src/components/__pycache__/__init__.cpython-310.pyc +0 -0
- src/components/__pycache__/necklaceTryOn.cpython-310.pyc +0 -0
- src/components/necklaceTryOn.py +235 -0
- src/pipelines/__init__.py +0 -0
- src/pipelines/__pycache__/__init__.cpython-310.pyc +0 -0
- src/pipelines/__pycache__/completePipeline.cpython-310.pyc +0 -0
- src/pipelines/completePipeline.py +13 -0
- src/utils/__init__.py +45 -0
- src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- src/utils/__pycache__/backgroundEnhancerArchitecture.cpython-310.pyc +0 -0
- src/utils/__pycache__/exceptions.cpython-310.pyc +0 -0
- src/utils/__pycache__/logger.cpython-310.pyc +0 -0
- src/utils/backgroundEnhancerArchitecture.py +454 -0
- src/utils/exceptions.py +16 -0
- src/utils/logger.py +22 -0
.github/workflows/dockerhub.yaml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Publish Docker image
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [main]
|
| 6 |
+
|
| 7 |
+
jobs:
|
| 8 |
+
push_to_registry:
|
| 9 |
+
name: Push Docker image to Docker Hub
|
| 10 |
+
runs-on: ubuntu-latest
|
| 11 |
+
permissions:
|
| 12 |
+
packages: write
|
| 13 |
+
contents: read
|
| 14 |
+
attestations: write
|
| 15 |
+
steps:
|
| 16 |
+
- name: Check out the repo
|
| 17 |
+
uses: actions/checkout@v4
|
| 18 |
+
|
| 19 |
+
- name: Log in to Docker Hub
|
| 20 |
+
uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a
|
| 21 |
+
with:
|
| 22 |
+
username: ${{ secrets.DOCKER_USERNAME }}
|
| 23 |
+
password: ${{ secrets.DOCKER_PASSWORD }}
|
| 24 |
+
|
| 25 |
+
- name: Extract metadata (tags, labels) for Docker
|
| 26 |
+
id: meta
|
| 27 |
+
uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
|
| 28 |
+
with:
|
| 29 |
+
images: ishworrsubedii/web_plugin_api
|
| 30 |
+
|
| 31 |
+
- name: Build and push Docker image
|
| 32 |
+
id: push
|
| 33 |
+
uses: docker/build-push-action@3b5e8027fcad23fda98b2e3ac259d8d67585f671
|
| 34 |
+
with:
|
| 35 |
+
context: .
|
| 36 |
+
file: ./Dockerfile
|
| 37 |
+
push: true
|
| 38 |
+
tags: ${{ steps.meta.outputs.tags }}
|
| 39 |
+
labels: ${{ steps.meta.outputs.labels }}
|
.gitignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
.idea/
|
| 6 |
+
|
| 7 |
+
*.log
|
Dockerfile
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /api
|
| 4 |
+
|
| 5 |
+
COPY . /api
|
| 6 |
+
|
| 7 |
+
RUN apt-get update && apt-get install -y
|
| 8 |
+
|
| 9 |
+
RUN apt install libgl1-mesa-glx -y
|
| 10 |
+
|
| 11 |
+
RUN apt-get install 'ffmpeg'\
|
| 12 |
+
'libsm6'\
|
| 13 |
+
'libxext6' -y
|
| 14 |
+
|
| 15 |
+
RUN pip install -r requirements.txt
|
| 16 |
+
|
| 17 |
+
RUN ulimit -s 2000
|
| 18 |
+
|
| 19 |
+
EXPOSE 8000
|
| 20 |
+
|
| 21 |
+
CMD ["uvicorn", "app:app", "--host","0.0.0.0","--port","8000"]
|
app.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi.encoders import jsonable_encoder
|
| 2 |
+
from src.utils import supabaseGetPublicURL, deductAndTrackCredit, returnBytesData
|
| 3 |
+
from fastapi import FastAPI, File, UploadFile, Header, HTTPException, Form, Depends
|
| 4 |
+
from src.pipelines.completePipeline import Pipeline
|
| 5 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 6 |
+
from fastapi.responses import JSONResponse
|
| 7 |
+
from supabase import create_client, Client
|
| 8 |
+
from typing import Dict, Union, List
|
| 9 |
+
from io import BytesIO
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import base64
|
| 13 |
+
import os
|
| 14 |
+
from pydantic import BaseModel
|
| 15 |
+
|
| 16 |
+
pipeline = Pipeline()
|
| 17 |
+
app = FastAPI(title="Magical Mirror Web Plugin")
|
| 18 |
+
|
| 19 |
+
app.add_middleware(
|
| 20 |
+
CORSMiddleware,
|
| 21 |
+
allow_origins=["*"],
|
| 22 |
+
allow_credentials=True,
|
| 23 |
+
allow_methods=["*"],
|
| 24 |
+
allow_headers=["*"],
|
| 25 |
+
)
|
| 26 |
+
url: str = os.environ["SUPABASE_URL"]
|
| 27 |
+
key: str = os.environ["SUPABASE_KEY"]
|
| 28 |
+
supabase: Client = create_client(url, key)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@app.post("/productData/{storeId}")
|
| 32 |
+
async def product_data(
|
| 33 |
+
storeId: str,
|
| 34 |
+
filterattributes: List[Dict[str, Union[str, int, float]]],
|
| 35 |
+
storename: str = Header(default="default")
|
| 36 |
+
):
|
| 37 |
+
"""Filters product data based on the provided attributes and store ID."""
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
response = supabase.table('MagicMirror').select("*").execute()
|
| 41 |
+
df = pd.DataFrame(response.dict()["data"])
|
| 42 |
+
|
| 43 |
+
df = df[df["StoreName"] == storeId]
|
| 44 |
+
|
| 45 |
+
# Preprocess filterattributes to handle multiple or duplicated attributes
|
| 46 |
+
attribute_dict = {}
|
| 47 |
+
for attr in filterattributes:
|
| 48 |
+
key, value = list(attr.items())[
|
| 49 |
+
0] # This will convert the dictionary into a list and get the key and value.
|
| 50 |
+
if key in attribute_dict: # This will check if the key is already present in the dictionary.
|
| 51 |
+
if isinstance(attribute_dict[key],
|
| 52 |
+
list): # This will create a list if there are multiple values for the same key and we are doing or operation.
|
| 53 |
+
attribute_dict[key].append(value) # This will append the value to the list.
|
| 54 |
+
else:
|
| 55 |
+
attribute_dict[key] = [attribute_dict[key], value]
|
| 56 |
+
else:
|
| 57 |
+
attribute_dict[key] = [value] # This will create a list if there is only one value for the key.
|
| 58 |
+
|
| 59 |
+
priceFrom = None
|
| 60 |
+
priceTo = None
|
| 61 |
+
weightFrom = None
|
| 62 |
+
weightTo = None
|
| 63 |
+
weightAscending = None
|
| 64 |
+
priceAscending = None
|
| 65 |
+
idAscending = None
|
| 66 |
+
dateAscending = None
|
| 67 |
+
|
| 68 |
+
for key, value in attribute_dict.items():
|
| 69 |
+
if key == 'priceFrom':
|
| 70 |
+
priceFrom = value[0]
|
| 71 |
+
|
| 72 |
+
elif key == "priceTo":
|
| 73 |
+
priceTo = value[0]
|
| 74 |
+
|
| 75 |
+
elif key == "priceAscending":
|
| 76 |
+
priceAscending = value[0]
|
| 77 |
+
|
| 78 |
+
elif key == "weightFrom":
|
| 79 |
+
weightFrom = value[0]
|
| 80 |
+
|
| 81 |
+
elif key == "weightTo":
|
| 82 |
+
weightTo = value[0]
|
| 83 |
+
|
| 84 |
+
elif key == "weightAscending":
|
| 85 |
+
weightAscending = value[0]
|
| 86 |
+
|
| 87 |
+
elif key == "idAscending":
|
| 88 |
+
idAscending = value[0]
|
| 89 |
+
|
| 90 |
+
elif key == "dateAscending":
|
| 91 |
+
dateAscending = value[0]
|
| 92 |
+
|
| 93 |
+
df["image_url"] = df.apply(
|
| 94 |
+
lambda row: supabaseGetPublicURL(f"{row['StoreName']}/{row['Category']}/image/{row['Id']}.png"),
|
| 95 |
+
axis=1)
|
| 96 |
+
df["thumbnail_url"] = df.apply(
|
| 97 |
+
lambda row: supabaseGetPublicURL(f"{row['StoreName']}/{row['Category']}/thumbnail/{row['Id']}.png"),
|
| 98 |
+
axis=1)
|
| 99 |
+
|
| 100 |
+
df.reset_index(drop=True, inplace=True)
|
| 101 |
+
for key, values in attribute_dict.items():
|
| 102 |
+
try:
|
| 103 |
+
df = df[df[key].isin(values)]
|
| 104 |
+
|
| 105 |
+
except:
|
| 106 |
+
pass
|
| 107 |
+
|
| 108 |
+
# applying filter for price and weight
|
| 109 |
+
if priceFrom is not None:
|
| 110 |
+
df = df[df["Price"] >= priceFrom]
|
| 111 |
+
if priceTo is not None:
|
| 112 |
+
df = df[df["Price"] <= priceTo]
|
| 113 |
+
if weightFrom is not None:
|
| 114 |
+
df = df[df["Weight"] >= weightFrom]
|
| 115 |
+
if weightTo is not None:
|
| 116 |
+
df = df[df["Weight"] <= weightTo]
|
| 117 |
+
|
| 118 |
+
if priceAscending is not None:
|
| 119 |
+
if priceAscending == 1:
|
| 120 |
+
value = True
|
| 121 |
+
|
| 122 |
+
else:
|
| 123 |
+
value = False
|
| 124 |
+
df = df.sort_values(by="Price", ascending=value)
|
| 125 |
+
if weightAscending is not None:
|
| 126 |
+
if weightAscending == 1:
|
| 127 |
+
value = True
|
| 128 |
+
|
| 129 |
+
else:
|
| 130 |
+
value = False
|
| 131 |
+
df = df.sort_values(by="Weight", ascending=value)
|
| 132 |
+
|
| 133 |
+
if idAscending is not None:
|
| 134 |
+
if idAscending == 1:
|
| 135 |
+
value = True
|
| 136 |
+
else:
|
| 137 |
+
value = False
|
| 138 |
+
df = df.sort_values(by="Id", ascending=value)
|
| 139 |
+
|
| 140 |
+
if dateAscending is not None:
|
| 141 |
+
if dateAscending == 1:
|
| 142 |
+
value = True
|
| 143 |
+
else:
|
| 144 |
+
value = False
|
| 145 |
+
df = df.sort_values(by="UpdatedAt", ascending=value)
|
| 146 |
+
|
| 147 |
+
df = df.drop(["CreatedAt", "EstimatedPrice"], axis=1)
|
| 148 |
+
|
| 149 |
+
result = {}
|
| 150 |
+
for _, row in df.iterrows():
|
| 151 |
+
category = row["Category"]
|
| 152 |
+
if category not in result: # this is for checking duplicate category
|
| 153 |
+
result[category] = []
|
| 154 |
+
result[category].append(row.to_dict())
|
| 155 |
+
|
| 156 |
+
return JSONResponse(content=jsonable_encoder(result)) # this will convert the result into json format.
|
| 157 |
+
|
| 158 |
+
except Exception as e:
|
| 159 |
+
raise HTTPException(status_code=500, detail=f"Failed to fetch or process data: {e}")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class NecklaceTryOnIDEntity(BaseModel):
|
| 163 |
+
necklaceImageId: str
|
| 164 |
+
necklaceCategory: str
|
| 165 |
+
storename: str
|
| 166 |
+
api_token: str
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
async def parse_necklace_try_on_id(necklaceImageId: str = Form(...),
|
| 170 |
+
necklaceCategory: str = Form(...),
|
| 171 |
+
storename: str = Form(...),
|
| 172 |
+
api_token: str = Form(...)) -> NecklaceTryOnIDEntity:
|
| 173 |
+
return NecklaceTryOnIDEntity(
|
| 174 |
+
necklaceImageId=necklaceImageId,
|
| 175 |
+
necklaceCategory=necklaceCategory,
|
| 176 |
+
storename=storename,
|
| 177 |
+
api_token=api_token
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
@app.post("/necklaceTryOnID")
|
| 182 |
+
async def necklace_try_on_id(necklace_try_on_id: NecklaceTryOnIDEntity = Depends(parse_necklace_try_on_id),
|
| 183 |
+
image: UploadFile = File(...)):
|
| 184 |
+
data, _ = supabase.table("APIKeyList").select("*").filter("API_KEY", "eq",
|
| 185 |
+
necklace_try_on_id.api_token).execute()
|
| 186 |
+
|
| 187 |
+
api_key_actual = data[1][0]['API_KEY']
|
| 188 |
+
if api_key_actual != necklace_try_on_id.api_token:
|
| 189 |
+
return JSONResponse(content={"error": "Invalid API Key"}, status_code=401)
|
| 190 |
+
|
| 191 |
+
else:
|
| 192 |
+
imageBytes = await image.read()
|
| 193 |
+
|
| 194 |
+
jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{necklace_try_on_id.storename}/{necklace_try_on_id.necklaceCategory}/image/{necklace_try_on_id.necklaceImageId}.png"
|
| 195 |
+
|
| 196 |
+
try:
|
| 197 |
+
image, jewellery = Image.open(BytesIO(imageBytes)), Image.open(returnBytesData(url=jewellery_url))
|
| 198 |
+
|
| 199 |
+
except:
|
| 200 |
+
error_message = {
|
| 201 |
+
"error": "The requested resource (Image, necklace category, or store) is not available. Please verify the availability and try again."
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
return JSONResponse(content=error_message, status_code=404)
|
| 205 |
+
|
| 206 |
+
result, headetText = await pipeline.necklaceTryOn(image=image, jewellery=jewellery,
|
| 207 |
+
storename=necklace_try_on_id.storename)
|
| 208 |
+
|
| 209 |
+
inMemFile = BytesIO()
|
| 210 |
+
result.save(inMemFile, format="WEBP", quality=85)
|
| 211 |
+
outputBytes = inMemFile.getvalue()
|
| 212 |
+
response = {
|
| 213 |
+
"output": f"data:image/WEBP;base64,{base64.b64encode(outputBytes).decode('utf-8')}"
|
| 214 |
+
}
|
| 215 |
+
creditResponse = deductAndTrackCredit(storename=necklace_try_on_id.storename, endpoint="/necklaceTryOnID")
|
| 216 |
+
if creditResponse == "No Credits Available":
|
| 217 |
+
response = {
|
| 218 |
+
"error": "No Credits Remaining"
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
return JSONResponse(content=response)
|
| 222 |
+
|
| 223 |
+
else:
|
| 224 |
+
return JSONResponse(content=response)
|
requirements.txt
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==2.1.0
|
| 2 |
+
annotated-types==0.7.0
|
| 3 |
+
anyio==4.4.0
|
| 4 |
+
attrs==23.2.0
|
| 5 |
+
certifi==2024.7.4
|
| 6 |
+
cffi==1.16.0
|
| 7 |
+
charset-normalizer==3.3.2
|
| 8 |
+
click==8.1.7
|
| 9 |
+
contourpy==1.2.1
|
| 10 |
+
cvzone==1.6.1
|
| 11 |
+
cycler==0.12.1
|
| 12 |
+
deprecation==2.1.0
|
| 13 |
+
dnspython==2.6.1
|
| 14 |
+
email_validator==2.2.0
|
| 15 |
+
exceptiongroup==1.2.1
|
| 16 |
+
fastapi==0.111.0
|
| 17 |
+
fastapi-cli==0.0.4
|
| 18 |
+
flatbuffers==24.3.25
|
| 19 |
+
fonttools==4.53.1
|
| 20 |
+
gotrue==2.5.4
|
| 21 |
+
h11==0.14.0
|
| 22 |
+
httpcore==1.0.5
|
| 23 |
+
httptools==0.6.1
|
| 24 |
+
httpx==0.27.0
|
| 25 |
+
idna==3.7
|
| 26 |
+
jax==0.4.30
|
| 27 |
+
jaxlib==0.4.30
|
| 28 |
+
Jinja2==3.1.4
|
| 29 |
+
kiwisolver==1.4.5
|
| 30 |
+
markdown-it-py==3.0.0
|
| 31 |
+
MarkupSafe==2.1.5
|
| 32 |
+
matplotlib==3.9.1
|
| 33 |
+
mdurl==0.1.2
|
| 34 |
+
mediapipe==0.10.14
|
| 35 |
+
ml-dtypes==0.4.0
|
| 36 |
+
numpy==2.0.0
|
| 37 |
+
opencv-contrib-python==4.10.0.84
|
| 38 |
+
opencv-python==4.10.0.84
|
| 39 |
+
opt-einsum==3.3.0
|
| 40 |
+
orjson==3.10.6
|
| 41 |
+
packaging==24.1
|
| 42 |
+
pandas==2.2.2
|
| 43 |
+
pillow==10.4.0
|
| 44 |
+
postgrest==0.16.8
|
| 45 |
+
protobuf==4.25.3
|
| 46 |
+
pycparser==2.22
|
| 47 |
+
pydantic==2.8.2
|
| 48 |
+
pydantic_core==2.20.1
|
| 49 |
+
Pygments==2.18.0
|
| 50 |
+
pyparsing==3.1.2
|
| 51 |
+
python-dateutil==2.9.0.post0
|
| 52 |
+
python-dotenv==1.0.1
|
| 53 |
+
python-multipart==0.0.9
|
| 54 |
+
pytz==2024.1
|
| 55 |
+
PyYAML==6.0.1
|
| 56 |
+
realtime==1.0.6
|
| 57 |
+
requests==2.32.3
|
| 58 |
+
rich==13.7.1
|
| 59 |
+
scikit-build==0.18.0
|
| 60 |
+
scipy==1.14.0
|
| 61 |
+
shellingham==1.5.4
|
| 62 |
+
six==1.16.0
|
| 63 |
+
sniffio==1.3.1
|
| 64 |
+
sounddevice==0.4.7
|
| 65 |
+
starlette==0.37.2
|
| 66 |
+
storage3==0.7.6
|
| 67 |
+
StrEnum==0.4.15
|
| 68 |
+
supabase==2.5.1
|
| 69 |
+
supafunc==0.4.6
|
| 70 |
+
tomli==2.0.1
|
| 71 |
+
typer==0.12.3
|
| 72 |
+
typing_extensions==4.12.2
|
| 73 |
+
tzdata==2024.1
|
| 74 |
+
ujson==5.10.0
|
| 75 |
+
urllib3==2.2.2
|
| 76 |
+
uvicorn==0.30.1
|
| 77 |
+
uvloop==0.19.0
|
| 78 |
+
watchfiles==0.22.0
|
| 79 |
+
websockets==12.0
|
| 80 |
+
-e .
|
setup.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup, find_packages
|
| 2 |
+
|
| 3 |
+
HYPER_E_DOT = "-e ."
|
| 4 |
+
def getRequirements(requirementsPath: str) -> list[str]:
|
| 5 |
+
with open(requirementsPath) as file:
|
| 6 |
+
requirements = file.read().split("\n")
|
| 7 |
+
requirements.remove(HYPER_E_DOT)
|
| 8 |
+
return requirements
|
| 9 |
+
|
| 10 |
+
setup(
|
| 11 |
+
name = "Magical-Mirror",
|
| 12 |
+
author = "Subramani Sivakumar",
|
| 13 |
+
author_email = "bwsubbu@gmail.com",
|
| 14 |
+
version = "0.1",
|
| 15 |
+
packages = find_packages(),
|
| 16 |
+
install_requires = getRequirements(requirementsPath = "./requirements.txt")
|
| 17 |
+
)
|
src/components/__init__.py
ADDED
|
File without changes
|
src/components/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (152 Bytes). View file
|
|
|
src/components/__pycache__/necklaceTryOn.cpython-310.pyc
ADDED
|
Binary file (6.35 kB). View file
|
|
|
src/components/necklaceTryOn.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from cvzone.FaceMeshModule import FaceMeshDetector
|
| 2 |
+
from src.utils import addWatermark, returnBytesData
|
| 3 |
+
from src.utils.exceptions import CustomException
|
| 4 |
+
from cvzone.PoseModule import PoseDetector
|
| 5 |
+
from src.utils.logger import logger
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Union
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
import cvzone
|
| 11 |
+
import math
|
| 12 |
+
import cv2
|
| 13 |
+
import gc
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class NecklaceTryOnConfig:
|
| 18 |
+
logoURL: str = "https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/MagicMirror/FullImages/{}.png"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class NecklaceTryOn:
|
| 22 |
+
def __init__(self) -> None:
|
| 23 |
+
self.detector = PoseDetector()
|
| 24 |
+
self.necklaceTryOnConfig = NecklaceTryOnConfig()
|
| 25 |
+
self.meshDetector = FaceMeshDetector(staticMode=True, maxFaces=1)
|
| 26 |
+
|
| 27 |
+
def necklaceTryOn(self, image: Image.Image, jewellery: Image.Image, storename: str) -> list[
|
| 28 |
+
Union[Image.Image, str]]:
|
| 29 |
+
try:
|
| 30 |
+
logger.info(f">>> NECKLACE TRY ON STARTED :: {storename} <<<")
|
| 31 |
+
|
| 32 |
+
# reading the images
|
| 33 |
+
image, jewellery = image.convert("RGB").resize((3000, 3000)), jewellery.convert("RGBA")
|
| 34 |
+
image = np.array(image)
|
| 35 |
+
copy_image = image.copy()
|
| 36 |
+
jewellery = np.array(jewellery)
|
| 37 |
+
|
| 38 |
+
logger.info(f"NECKLACE TRY ON :: detecting pose and landmarks :: {storename}")
|
| 39 |
+
|
| 40 |
+
image = self.detector.findPose(image)
|
| 41 |
+
lmList, _ = self.detector.findPosition(image, bboxWithHands=False, draw=False)
|
| 42 |
+
|
| 43 |
+
img, faces = self.meshDetector.findFaceMesh(image, draw=False)
|
| 44 |
+
leftLandmarkIndex = 172
|
| 45 |
+
rightLandmarkIndex = 397
|
| 46 |
+
|
| 47 |
+
leftLandmark, rightLandmark = faces[0][leftLandmarkIndex], faces[0][rightLandmarkIndex]
|
| 48 |
+
landmarksDistance = int(
|
| 49 |
+
((leftLandmark[0] - rightLandmark[0]) ** 2 + (leftLandmark[1] - rightLandmark[1]) ** 2) ** 0.5)
|
| 50 |
+
|
| 51 |
+
logger.info(f"NECKLACE TRY ON :: estimating neck points :: {storename}")
|
| 52 |
+
|
| 53 |
+
# avg_x1 = int(leftLandmark[0] - landmarksDistance * 0.15) -> V2.1
|
| 54 |
+
avg_x1 = int(leftLandmark[0] - landmarksDistance * 0.12)
|
| 55 |
+
avg_x2 = int(rightLandmark[0] + landmarksDistance * 0.12)
|
| 56 |
+
|
| 57 |
+
avg_y1 = int(leftLandmark[1] + landmarksDistance * 0.5)
|
| 58 |
+
avg_y2 = int(rightLandmark[1] + landmarksDistance * 0.5)
|
| 59 |
+
|
| 60 |
+
logger.info(f"NECKLACE TRY ON :: scaling the necklace image :: {storename}")
|
| 61 |
+
|
| 62 |
+
if avg_y2 < avg_y1:
|
| 63 |
+
angle = math.ceil(
|
| 64 |
+
self.detector.findAngle(
|
| 65 |
+
p1=(avg_x2, avg_y2), p2=(avg_x1, avg_y1), p3=(avg_x2, avg_y1)
|
| 66 |
+
)[0]
|
| 67 |
+
)
|
| 68 |
+
else:
|
| 69 |
+
angle = math.ceil(
|
| 70 |
+
self.detector.findAngle(
|
| 71 |
+
p1=(avg_x2, avg_y2), p2=(avg_x1, avg_y1), p3=(avg_x2, avg_y1)
|
| 72 |
+
)[0]
|
| 73 |
+
)
|
| 74 |
+
angle = angle * -1
|
| 75 |
+
|
| 76 |
+
xdist = avg_x2 - avg_x1
|
| 77 |
+
origImgRatio = xdist / jewellery.shape[1]
|
| 78 |
+
ydist = jewellery.shape[0] * origImgRatio
|
| 79 |
+
|
| 80 |
+
logger.info(f"NECKLACE TRY ON :: adding offset based on the necklace shape :: {storename}")
|
| 81 |
+
|
| 82 |
+
image_gray = cv2.cvtColor(jewellery, cv2.COLOR_BGRA2GRAY)
|
| 83 |
+
for offset_orig in range(image_gray.shape[1]):
|
| 84 |
+
pixel_value = image_gray[0, :][offset_orig]
|
| 85 |
+
if (pixel_value != 255) & (pixel_value != 0):
|
| 86 |
+
break
|
| 87 |
+
else:
|
| 88 |
+
continue
|
| 89 |
+
offset = int(0.8 * xdist * (offset_orig / jewellery.shape[1]))
|
| 90 |
+
jewellery = cv2.resize(
|
| 91 |
+
jewellery, (int(xdist), int(ydist)), interpolation=cv2.INTER_AREA
|
| 92 |
+
)
|
| 93 |
+
jewellery = cvzone.rotateImage(jewellery, angle)
|
| 94 |
+
y_coordinate = avg_y1 - offset
|
| 95 |
+
available_space = copy_image.shape[0] - y_coordinate
|
| 96 |
+
extra = jewellery.shape[0] - available_space
|
| 97 |
+
|
| 98 |
+
logger.info(f"NECKLACE TRY ON :: generating necklace placement status :: {storename}")
|
| 99 |
+
|
| 100 |
+
if extra > 0:
|
| 101 |
+
headerText = "To see more of the necklace, please step back slightly."
|
| 102 |
+
else:
|
| 103 |
+
headerText = "success"
|
| 104 |
+
|
| 105 |
+
logger.info(f"NECKLACE TRY ON :: generating output :: {storename}")
|
| 106 |
+
|
| 107 |
+
result = cvzone.overlayPNG(copy_image, jewellery, (avg_x1, y_coordinate))
|
| 108 |
+
image = Image.fromarray(result.astype(np.uint8))
|
| 109 |
+
logo = Image.open(returnBytesData(url=self.necklaceTryOnConfig.logoURL.format(storename)))
|
| 110 |
+
result = addWatermark(background=image, logo=logo)
|
| 111 |
+
|
| 112 |
+
gc.collect()
|
| 113 |
+
|
| 114 |
+
return [result, headerText]
|
| 115 |
+
|
| 116 |
+
except Exception as e:
|
| 117 |
+
logger.error(f"{CustomException(e)}:: {storename}")
|
| 118 |
+
raise CustomException(e)
|
| 119 |
+
|
| 120 |
+
def necklaceTryOnV3(self, image: Image.Image, jewellery: Image.Image, storename: str) -> list[
|
| 121 |
+
Union[Image.Image, str]]:
|
| 122 |
+
try:
|
| 123 |
+
logger.info(f">>> NECKLACE TRY ON STARTED :: {storename} <<<")
|
| 124 |
+
|
| 125 |
+
# reading the images
|
| 126 |
+
image, jewellery = image.convert("RGB"), jewellery.convert("RGBA")
|
| 127 |
+
image = np.array(image.resize((4000, 4000)))
|
| 128 |
+
copy_image = image.copy()
|
| 129 |
+
jewellery = np.array(jewellery)
|
| 130 |
+
|
| 131 |
+
logger.info(f"NECKLACE TRY ON :: detecting pose and landmarks :: {storename}")
|
| 132 |
+
|
| 133 |
+
image = self.detector.findPose(image)
|
| 134 |
+
lmList, _ = self.detector.findPosition(image, bboxWithHands=False, draw=False)
|
| 135 |
+
meshDetector = FaceMeshDetector(staticMode=True, maxFaces=1)
|
| 136 |
+
img, faces = meshDetector.findFaceMesh(image, draw=False)
|
| 137 |
+
left_lip_point = faces[0][61]
|
| 138 |
+
right_lip_point = faces[0][291]
|
| 139 |
+
pt12, pt11, pt10, pt9 = (
|
| 140 |
+
lmList[12][:2],
|
| 141 |
+
lmList[11][:2],
|
| 142 |
+
lmList[10][:2],
|
| 143 |
+
lmList[9][:2],
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
mid_lips = (
|
| 147 |
+
int((left_lip_point[0] + right_lip_point[0]) / 2), int((left_lip_point[1] + right_lip_point[1]) / 2))
|
| 148 |
+
|
| 149 |
+
mid_lips_x1 = int(pt12[0] + (mid_lips[0] - pt12[0]) / 2)
|
| 150 |
+
mid_lips_y1 = int(pt12[1] + (mid_lips[1] - pt12[1]) / 2)
|
| 151 |
+
|
| 152 |
+
mid_lips_x2 = int(pt11[0] + (mid_lips[0] - pt11[0]) / 2)
|
| 153 |
+
mid_lips_y2 = int(pt11[1] + (mid_lips[1] - pt11[1]) / 2)
|
| 154 |
+
|
| 155 |
+
# left right lip
|
| 156 |
+
left_right_lip_org_x11 = int(pt12[0] + (right_lip_point[0] - pt12[0]) / 2)
|
| 157 |
+
left_right_lip_org_y11 = int(pt12[1] + (right_lip_point[1] - pt12[1]) / 2)
|
| 158 |
+
|
| 159 |
+
left_right_lip_org_x12 = int(pt11[0] + (left_lip_point[0] - pt11[0]) / 2)
|
| 160 |
+
left_right_lip_org_y12 = int(pt11[1] + (left_lip_point[1] - pt11[1]) / 2)
|
| 161 |
+
|
| 162 |
+
# left right lip 2
|
| 163 |
+
left_right_lip_org_x21 = int(pt12[0] + (left_lip_point[0] - pt12[0]) / 2)
|
| 164 |
+
left_right_lip_org_y21 = int(pt12[1] + (left_lip_point[1] - pt12[1]) / 2)
|
| 165 |
+
|
| 166 |
+
left_right_lip_org_x22 = int(pt11[0] + (right_lip_point[0] - pt11[0]) / 2)
|
| 167 |
+
left_right_lip_org_y22 = int(pt11[1] + (right_lip_point[1] - pt11[1]) / 2)
|
| 168 |
+
|
| 169 |
+
logger.info(f"NECKLACE TRY ON :: estimating neck points :: {storename}")
|
| 170 |
+
|
| 171 |
+
avg_x1 = int((mid_lips_x1 + left_right_lip_org_x11 + left_right_lip_org_x21) / 3)
|
| 172 |
+
avg_y1 = int((mid_lips_y1 + left_right_lip_org_y11 + left_right_lip_org_y21) / 3)
|
| 173 |
+
|
| 174 |
+
avg_x2 = int((mid_lips_x2 + left_right_lip_org_x12 + left_right_lip_org_x22) / 3)
|
| 175 |
+
avg_y2 = int((mid_lips_y2 + left_right_lip_org_y12 + left_right_lip_org_y22) / 3)
|
| 176 |
+
|
| 177 |
+
logger.info(f"NECKLACE TRY ON :: scaling the necklace image :: {storename}")
|
| 178 |
+
|
| 179 |
+
if avg_y2 < avg_y1:
|
| 180 |
+
angle = math.ceil(
|
| 181 |
+
self.detector.findAngle(
|
| 182 |
+
p1=(avg_x2, avg_y2), p2=(avg_x1, avg_y1), p3=(avg_x2, avg_y1)
|
| 183 |
+
)[0]
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
angle = math.ceil(
|
| 187 |
+
self.detector.findAngle(
|
| 188 |
+
p1=(avg_x2, avg_y2), p2=(avg_x1, avg_y1), p3=(avg_x2, avg_y1)
|
| 189 |
+
)[0]
|
| 190 |
+
)
|
| 191 |
+
angle = angle * -1
|
| 192 |
+
|
| 193 |
+
xdist = avg_x2 - avg_x1
|
| 194 |
+
origImgRatio = xdist / jewellery.shape[1]
|
| 195 |
+
ydist = jewellery.shape[0] * origImgRatio
|
| 196 |
+
|
| 197 |
+
logger.info(f"NECKLACE TRY ON :: adding offset based on the necklace shape :: {storename}")
|
| 198 |
+
|
| 199 |
+
image_gray = cv2.cvtColor(jewellery, cv2.COLOR_BGRA2GRAY)
|
| 200 |
+
for offset_orig in range(image_gray.shape[1]):
|
| 201 |
+
pixel_value = image_gray[0, :][offset_orig]
|
| 202 |
+
if (pixel_value != 255) & (pixel_value != 0):
|
| 203 |
+
break
|
| 204 |
+
else:
|
| 205 |
+
continue
|
| 206 |
+
offset = int(0.8 * xdist * (offset_orig / jewellery.shape[1]))
|
| 207 |
+
jewellery = cv2.resize(
|
| 208 |
+
jewellery, (int(xdist), int(ydist)), interpolation=cv2.INTER_AREA
|
| 209 |
+
)
|
| 210 |
+
jewellery = cvzone.rotateImage(jewellery, angle)
|
| 211 |
+
y_coordinate = avg_y1 - offset
|
| 212 |
+
available_space = copy_image.shape[0] - y_coordinate
|
| 213 |
+
extra = jewellery.shape[0] - available_space
|
| 214 |
+
|
| 215 |
+
logger.info(f"NECKLACE TRY ON :: generating necklace placement status :: {storename}")
|
| 216 |
+
|
| 217 |
+
if extra > 0:
|
| 218 |
+
headerText = "To see more of the necklace, please step back slightly."
|
| 219 |
+
else:
|
| 220 |
+
headerText = "success"
|
| 221 |
+
|
| 222 |
+
logger.info(f"NECKLACE TRY ON :: generating output :: {storename}")
|
| 223 |
+
|
| 224 |
+
result = cvzone.overlayPNG(copy_image, jewellery, (avg_x1, y_coordinate))
|
| 225 |
+
image = Image.fromarray(result.astype(np.uint8))
|
| 226 |
+
logo = Image.open(returnBytesData(url=self.necklaceTryOnConfig.logoURL.format({storename})))
|
| 227 |
+
result = addWatermark(background=image, logo=logo)
|
| 228 |
+
|
| 229 |
+
gc.collect()
|
| 230 |
+
|
| 231 |
+
return [result, headerText]
|
| 232 |
+
|
| 233 |
+
except Exception as e:
|
| 234 |
+
logger.error(f"{CustomException(e)}:: {storename}")
|
| 235 |
+
raise CustomException(e)
|
src/pipelines/__init__.py
ADDED
|
File without changes
|
src/pipelines/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (151 Bytes). View file
|
|
|
src/pipelines/__pycache__/completePipeline.cpython-310.pyc
ADDED
|
Binary file (925 Bytes). View file
|
|
|
src/pipelines/completePipeline.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.components.necklaceTryOn import NecklaceTryOn
|
| 2 |
+
from typing import Union
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Pipeline:
|
| 7 |
+
def __init__(self) -> None:
|
| 8 |
+
self.necklaceTryOnObj = NecklaceTryOn()
|
| 9 |
+
|
| 10 |
+
async def necklaceTryOn(self, image: Image.Image, jewellery: Image.Image, storename: str) -> list[
|
| 11 |
+
Union[Image.Image, str]]:
|
| 12 |
+
result, headerText = self.necklaceTryOnObj.necklaceTryOn(image=image, jewellery=jewellery, storename=storename)
|
| 13 |
+
return [result, headerText]
|
src/utils/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from supabase import create_client, Client
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from io import BytesIO
|
| 4 |
+
import requests
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# function to add watermark to images
|
| 9 |
+
def addWatermark(background: Image.Image, logo: Image.Image) -> Image.Image:
|
| 10 |
+
background = background.convert("RGBA")
|
| 11 |
+
logo = logo.convert("RGBA").resize((int(0.08 * background.size[0]), int(0.08 * background.size[0])))
|
| 12 |
+
background.paste(logo, (10, background.size[1] - logo.size[1] - 10), logo)
|
| 13 |
+
return background
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# function to download an image from url and return as bytes objects
|
| 17 |
+
def returnBytesData(url: str) -> BytesIO:
|
| 18 |
+
response = requests.get(url)
|
| 19 |
+
return BytesIO(response.content)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# function to get public URLs of paths
|
| 23 |
+
def supabaseGetPublicURL(path: str) -> str:
|
| 24 |
+
url_string = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{path}"
|
| 25 |
+
return url_string.replace(" ", "%20")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# function to deduct credit
|
| 29 |
+
def deductAndTrackCredit(storename: str, endpoint: str) -> str:
|
| 30 |
+
url: str = os.environ["SUPABASE_URL"]
|
| 31 |
+
key: str = os.environ["SUPABASE_KEY"]
|
| 32 |
+
supabase: Client = create_client(url, key)
|
| 33 |
+
current, _ = supabase.table('ClientConfig').select('CreditBalance').eq("StoreName", f"{storename}").execute()
|
| 34 |
+
if current[1] == []:
|
| 35 |
+
return "Not Found"
|
| 36 |
+
else:
|
| 37 |
+
current = current[1][0]["CreditBalance"]
|
| 38 |
+
if current > 0:
|
| 39 |
+
data, _ = supabase.table('ClientConfig').update({'CreditBalance': current - 1}).eq("StoreName",
|
| 40 |
+
f"{storename}").execute()
|
| 41 |
+
data, _ = supabase.table('UsageHistory').insert(
|
| 42 |
+
{'StoreName': f"{storename}", 'APIEndpoint': f"{endpoint}"}).execute()
|
| 43 |
+
return "Success"
|
| 44 |
+
else:
|
| 45 |
+
return "No Credits Available"
|
src/utils/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.7 kB). View file
|
|
|
src/utils/__pycache__/backgroundEnhancerArchitecture.cpython-310.pyc
ADDED
|
Binary file (9.92 kB). View file
|
|
|
src/utils/__pycache__/exceptions.cpython-310.pyc
ADDED
|
Binary file (1.03 kB). View file
|
|
|
src/utils/__pycache__/logger.cpython-310.pyc
ADDED
|
Binary file (683 Bytes). View file
|
|
|
src/utils/backgroundEnhancerArchitecture.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class REBNCONV(nn.Module):
|
| 8 |
+
def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
|
| 9 |
+
super(REBNCONV, self).__init__()
|
| 10 |
+
|
| 11 |
+
self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
|
| 12 |
+
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
| 13 |
+
self.relu_s1 = nn.ReLU(inplace=True)
|
| 14 |
+
|
| 15 |
+
def forward(self, x):
|
| 16 |
+
hx = x
|
| 17 |
+
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
| 18 |
+
|
| 19 |
+
return xout
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
| 23 |
+
def _upsample_like(src, tar):
|
| 24 |
+
src = F.interpolate(src, size=tar.shape[2:], mode='bilinear')
|
| 25 |
+
|
| 26 |
+
return src
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
### RSU-7 ###
|
| 30 |
+
class RSU7(nn.Module):
|
| 31 |
+
|
| 32 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
|
| 33 |
+
super(RSU7, self).__init__()
|
| 34 |
+
|
| 35 |
+
self.in_ch = in_ch
|
| 36 |
+
self.mid_ch = mid_ch
|
| 37 |
+
self.out_ch = out_ch
|
| 38 |
+
|
| 39 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
|
| 40 |
+
|
| 41 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 42 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 43 |
+
|
| 44 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 45 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 46 |
+
|
| 47 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 48 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 49 |
+
|
| 50 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 51 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 52 |
+
|
| 53 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 54 |
+
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 55 |
+
|
| 56 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 57 |
+
|
| 58 |
+
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 59 |
+
|
| 60 |
+
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 61 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 62 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 63 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 64 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 65 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 66 |
+
|
| 67 |
+
def forward(self, x):
|
| 68 |
+
b, c, h, w = x.shape
|
| 69 |
+
|
| 70 |
+
hx = x
|
| 71 |
+
hxin = self.rebnconvin(hx)
|
| 72 |
+
|
| 73 |
+
hx1 = self.rebnconv1(hxin)
|
| 74 |
+
hx = self.pool1(hx1)
|
| 75 |
+
|
| 76 |
+
hx2 = self.rebnconv2(hx)
|
| 77 |
+
hx = self.pool2(hx2)
|
| 78 |
+
|
| 79 |
+
hx3 = self.rebnconv3(hx)
|
| 80 |
+
hx = self.pool3(hx3)
|
| 81 |
+
|
| 82 |
+
hx4 = self.rebnconv4(hx)
|
| 83 |
+
hx = self.pool4(hx4)
|
| 84 |
+
|
| 85 |
+
hx5 = self.rebnconv5(hx)
|
| 86 |
+
hx = self.pool5(hx5)
|
| 87 |
+
|
| 88 |
+
hx6 = self.rebnconv6(hx)
|
| 89 |
+
|
| 90 |
+
hx7 = self.rebnconv7(hx6)
|
| 91 |
+
|
| 92 |
+
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
|
| 93 |
+
hx6dup = _upsample_like(hx6d, hx5)
|
| 94 |
+
|
| 95 |
+
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
|
| 96 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
| 97 |
+
|
| 98 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
| 99 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
| 100 |
+
|
| 101 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
| 102 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
| 103 |
+
|
| 104 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
| 105 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
| 106 |
+
|
| 107 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
| 108 |
+
|
| 109 |
+
return hx1d + hxin
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
### RSU-6 ###
|
| 113 |
+
class RSU6(nn.Module):
|
| 114 |
+
|
| 115 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
| 116 |
+
super(RSU6, self).__init__()
|
| 117 |
+
|
| 118 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
| 119 |
+
|
| 120 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 121 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 122 |
+
|
| 123 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 124 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 125 |
+
|
| 126 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 127 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 128 |
+
|
| 129 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 130 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 131 |
+
|
| 132 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 133 |
+
|
| 134 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 135 |
+
|
| 136 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 137 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 138 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 139 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 140 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 141 |
+
|
| 142 |
+
def forward(self, x):
|
| 143 |
+
hx = x
|
| 144 |
+
|
| 145 |
+
hxin = self.rebnconvin(hx)
|
| 146 |
+
|
| 147 |
+
hx1 = self.rebnconv1(hxin)
|
| 148 |
+
hx = self.pool1(hx1)
|
| 149 |
+
|
| 150 |
+
hx2 = self.rebnconv2(hx)
|
| 151 |
+
hx = self.pool2(hx2)
|
| 152 |
+
|
| 153 |
+
hx3 = self.rebnconv3(hx)
|
| 154 |
+
hx = self.pool3(hx3)
|
| 155 |
+
|
| 156 |
+
hx4 = self.rebnconv4(hx)
|
| 157 |
+
hx = self.pool4(hx4)
|
| 158 |
+
|
| 159 |
+
hx5 = self.rebnconv5(hx)
|
| 160 |
+
|
| 161 |
+
hx6 = self.rebnconv6(hx5)
|
| 162 |
+
|
| 163 |
+
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
|
| 164 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
| 165 |
+
|
| 166 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
| 167 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
| 168 |
+
|
| 169 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
| 170 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
| 171 |
+
|
| 172 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
| 173 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
| 174 |
+
|
| 175 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
| 176 |
+
|
| 177 |
+
return hx1d + hxin
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
### RSU-5 ###
|
| 181 |
+
class RSU5(nn.Module):
|
| 182 |
+
|
| 183 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
| 184 |
+
super(RSU5, self).__init__()
|
| 185 |
+
|
| 186 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
| 187 |
+
|
| 188 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 189 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 190 |
+
|
| 191 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 192 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 193 |
+
|
| 194 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 195 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 196 |
+
|
| 197 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 198 |
+
|
| 199 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 200 |
+
|
| 201 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 202 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 203 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 204 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 205 |
+
|
| 206 |
+
def forward(self, x):
|
| 207 |
+
hx = x
|
| 208 |
+
|
| 209 |
+
hxin = self.rebnconvin(hx)
|
| 210 |
+
|
| 211 |
+
hx1 = self.rebnconv1(hxin)
|
| 212 |
+
hx = self.pool1(hx1)
|
| 213 |
+
|
| 214 |
+
hx2 = self.rebnconv2(hx)
|
| 215 |
+
hx = self.pool2(hx2)
|
| 216 |
+
|
| 217 |
+
hx3 = self.rebnconv3(hx)
|
| 218 |
+
hx = self.pool3(hx3)
|
| 219 |
+
|
| 220 |
+
hx4 = self.rebnconv4(hx)
|
| 221 |
+
|
| 222 |
+
hx5 = self.rebnconv5(hx4)
|
| 223 |
+
|
| 224 |
+
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
|
| 225 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
| 226 |
+
|
| 227 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
| 228 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
| 229 |
+
|
| 230 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
| 231 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
| 232 |
+
|
| 233 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
| 234 |
+
|
| 235 |
+
return hx1d + hxin
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
### RSU-4 ###
|
| 239 |
+
class RSU4(nn.Module):
|
| 240 |
+
|
| 241 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
| 242 |
+
super(RSU4, self).__init__()
|
| 243 |
+
|
| 244 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
| 245 |
+
|
| 246 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 247 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 248 |
+
|
| 249 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 250 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 251 |
+
|
| 252 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 253 |
+
|
| 254 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 255 |
+
|
| 256 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 257 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 258 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 259 |
+
|
| 260 |
+
def forward(self, x):
|
| 261 |
+
hx = x
|
| 262 |
+
|
| 263 |
+
hxin = self.rebnconvin(hx)
|
| 264 |
+
|
| 265 |
+
hx1 = self.rebnconv1(hxin)
|
| 266 |
+
hx = self.pool1(hx1)
|
| 267 |
+
|
| 268 |
+
hx2 = self.rebnconv2(hx)
|
| 269 |
+
hx = self.pool2(hx2)
|
| 270 |
+
|
| 271 |
+
hx3 = self.rebnconv3(hx)
|
| 272 |
+
|
| 273 |
+
hx4 = self.rebnconv4(hx3)
|
| 274 |
+
|
| 275 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
| 276 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
| 277 |
+
|
| 278 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
| 279 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
| 280 |
+
|
| 281 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
| 282 |
+
|
| 283 |
+
return hx1d + hxin
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
### RSU-4F ###
|
| 287 |
+
class RSU4F(nn.Module):
|
| 288 |
+
|
| 289 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
| 290 |
+
super(RSU4F, self).__init__()
|
| 291 |
+
|
| 292 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
| 293 |
+
|
| 294 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 295 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 296 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
|
| 297 |
+
|
| 298 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
|
| 299 |
+
|
| 300 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
|
| 301 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
|
| 302 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 303 |
+
|
| 304 |
+
def forward(self, x):
|
| 305 |
+
hx = x
|
| 306 |
+
|
| 307 |
+
hxin = self.rebnconvin(hx)
|
| 308 |
+
|
| 309 |
+
hx1 = self.rebnconv1(hxin)
|
| 310 |
+
hx2 = self.rebnconv2(hx1)
|
| 311 |
+
hx3 = self.rebnconv3(hx2)
|
| 312 |
+
|
| 313 |
+
hx4 = self.rebnconv4(hx3)
|
| 314 |
+
|
| 315 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
| 316 |
+
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
|
| 317 |
+
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
|
| 318 |
+
|
| 319 |
+
return hx1d + hxin
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class myrebnconv(nn.Module):
|
| 323 |
+
def __init__(self, in_ch=3,
|
| 324 |
+
out_ch=1,
|
| 325 |
+
kernel_size=3,
|
| 326 |
+
stride=1,
|
| 327 |
+
padding=1,
|
| 328 |
+
dilation=1,
|
| 329 |
+
groups=1):
|
| 330 |
+
super(myrebnconv, self).__init__()
|
| 331 |
+
|
| 332 |
+
self.conv = nn.Conv2d(in_ch,
|
| 333 |
+
out_ch,
|
| 334 |
+
kernel_size=kernel_size,
|
| 335 |
+
stride=stride,
|
| 336 |
+
padding=padding,
|
| 337 |
+
dilation=dilation,
|
| 338 |
+
groups=groups)
|
| 339 |
+
self.bn = nn.BatchNorm2d(out_ch)
|
| 340 |
+
self.rl = nn.ReLU(inplace=True)
|
| 341 |
+
|
| 342 |
+
def forward(self, x):
|
| 343 |
+
return self.rl(self.bn(self.conv(x)))
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class BackgroundEnhancerArchitecture(nn.Module, PyTorchModelHubMixin):
|
| 347 |
+
|
| 348 |
+
def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}):
|
| 349 |
+
super(BackgroundEnhancerArchitecture, self).__init__()
|
| 350 |
+
in_ch = config["in_ch"]
|
| 351 |
+
out_ch = config["out_ch"]
|
| 352 |
+
self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
|
| 353 |
+
self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 354 |
+
|
| 355 |
+
self.stage1 = RSU7(64, 32, 64)
|
| 356 |
+
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 357 |
+
|
| 358 |
+
self.stage2 = RSU6(64, 32, 128)
|
| 359 |
+
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 360 |
+
|
| 361 |
+
self.stage3 = RSU5(128, 64, 256)
|
| 362 |
+
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 363 |
+
|
| 364 |
+
self.stage4 = RSU4(256, 128, 512)
|
| 365 |
+
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 366 |
+
|
| 367 |
+
self.stage5 = RSU4F(512, 256, 512)
|
| 368 |
+
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 369 |
+
|
| 370 |
+
self.stage6 = RSU4F(512, 256, 512)
|
| 371 |
+
|
| 372 |
+
# decoder
|
| 373 |
+
self.stage5d = RSU4F(1024, 256, 512)
|
| 374 |
+
self.stage4d = RSU4(1024, 128, 256)
|
| 375 |
+
self.stage3d = RSU5(512, 64, 128)
|
| 376 |
+
self.stage2d = RSU6(256, 32, 64)
|
| 377 |
+
self.stage1d = RSU7(128, 16, 64)
|
| 378 |
+
|
| 379 |
+
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
| 380 |
+
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
| 381 |
+
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
|
| 382 |
+
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
|
| 383 |
+
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
|
| 384 |
+
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
|
| 385 |
+
|
| 386 |
+
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
|
| 387 |
+
|
| 388 |
+
def forward(self, x):
|
| 389 |
+
hx = x
|
| 390 |
+
|
| 391 |
+
hxin = self.conv_in(hx)
|
| 392 |
+
# hx = self.pool_in(hxin)
|
| 393 |
+
|
| 394 |
+
# stage 1
|
| 395 |
+
hx1 = self.stage1(hxin)
|
| 396 |
+
hx = self.pool12(hx1)
|
| 397 |
+
|
| 398 |
+
# stage 2
|
| 399 |
+
hx2 = self.stage2(hx)
|
| 400 |
+
hx = self.pool23(hx2)
|
| 401 |
+
|
| 402 |
+
# stage 3
|
| 403 |
+
hx3 = self.stage3(hx)
|
| 404 |
+
hx = self.pool34(hx3)
|
| 405 |
+
|
| 406 |
+
# stage 4
|
| 407 |
+
hx4 = self.stage4(hx)
|
| 408 |
+
hx = self.pool45(hx4)
|
| 409 |
+
|
| 410 |
+
# stage 5
|
| 411 |
+
hx5 = self.stage5(hx)
|
| 412 |
+
hx = self.pool56(hx5)
|
| 413 |
+
|
| 414 |
+
# stage 6
|
| 415 |
+
hx6 = self.stage6(hx)
|
| 416 |
+
hx6up = _upsample_like(hx6, hx5)
|
| 417 |
+
|
| 418 |
+
# -------------------- decoder --------------------
|
| 419 |
+
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
| 420 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
| 421 |
+
|
| 422 |
+
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
| 423 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
| 424 |
+
|
| 425 |
+
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
| 426 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
| 427 |
+
|
| 428 |
+
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
| 429 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
| 430 |
+
|
| 431 |
+
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
| 432 |
+
|
| 433 |
+
# side output
|
| 434 |
+
d1 = self.side1(hx1d)
|
| 435 |
+
d1 = _upsample_like(d1, x)
|
| 436 |
+
|
| 437 |
+
d2 = self.side2(hx2d)
|
| 438 |
+
d2 = _upsample_like(d2, x)
|
| 439 |
+
|
| 440 |
+
d3 = self.side3(hx3d)
|
| 441 |
+
d3 = _upsample_like(d3, x)
|
| 442 |
+
|
| 443 |
+
d4 = self.side4(hx4d)
|
| 444 |
+
d4 = _upsample_like(d4, x)
|
| 445 |
+
|
| 446 |
+
d5 = self.side5(hx5d)
|
| 447 |
+
d5 = _upsample_like(d5, x)
|
| 448 |
+
|
| 449 |
+
d6 = self.side6(hx6)
|
| 450 |
+
d6 = _upsample_like(d6, x)
|
| 451 |
+
|
| 452 |
+
return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1d, hx2d,
|
| 453 |
+
hx3d, hx4d,
|
| 454 |
+
hx5d, hx6]
|
src/utils/exceptions.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
def error_message_detail(error):
|
| 4 |
+
_, _, exc_info = sys.exc_info()
|
| 5 |
+
filename = exc_info.tb_frame.f_code.co_filename
|
| 6 |
+
lineno = exc_info.tb_lineno
|
| 7 |
+
error_message = "Error encountered in line no [{}], filename : [{}], saying [{}]".format(lineno, filename, error)
|
| 8 |
+
return error_message
|
| 9 |
+
|
| 10 |
+
class CustomException(Exception):
|
| 11 |
+
def __init__(self, error_message):
|
| 12 |
+
super().__init__(error_message)
|
| 13 |
+
self.error_message = error_message_detail(error_message)
|
| 14 |
+
|
| 15 |
+
def __str__(self) -> str:
|
| 16 |
+
return self.error_message
|
src/utils/logger.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
logger = logging.getLogger(__name__)
|
| 5 |
+
logger.setLevel(logging.INFO)
|
| 6 |
+
|
| 7 |
+
log_dir = os.path.join(os.getcwd(), "logs")
|
| 8 |
+
os.makedirs(log_dir, exist_ok = True)
|
| 9 |
+
|
| 10 |
+
LOG_FILE = os.path.join(log_dir, "running_logs.log")
|
| 11 |
+
|
| 12 |
+
logFormat = "[%(asctime)s: %(levelname)s: %(module)s: %(message)s]"
|
| 13 |
+
logFormatter = logging.Formatter(fmt = logFormat, style = "%")
|
| 14 |
+
|
| 15 |
+
streamHandler = logging.StreamHandler()
|
| 16 |
+
streamHandler.setFormatter(logFormatter)
|
| 17 |
+
|
| 18 |
+
fileHandler = logging.FileHandler(filename = LOG_FILE)
|
| 19 |
+
fileHandler.setFormatter(logFormatter)
|
| 20 |
+
|
| 21 |
+
logger.addHandler(streamHandler)
|
| 22 |
+
logger.addHandler(fileHandler)
|