HariLogicgo committed on
Commit 861422e · 1 Parent(s): 5517f63

new gemini api

Files changed (10):
  1. API_Usage_Guide.md +8 -4
  2. HTTP_API_Documentation.txt +10 -6
  3. README.md +0 -0
  4. api/main.py +93 -56
  5. app.py +0 -2
  6. assets/big-lama.pt +0 -3
  7. requirements.txt +8 -26
  8. src/core.py +196 -536
  9. src/helper.py +0 -87
  10. src/st_style.py +0 -42
API_Usage_Guide.md CHANGED
@@ -4,7 +4,9 @@
 This guide provides step-by-step instructions for using the Photo Object Removal API to remove objects from images using AI inpainting.
 
 **Base URL:** `https://logicgoinfotechspaces-object-remover.hf.space`
 **Authentication:** Bearer token (optional)
+**Storage:** Uploaded images and masks are saved in MongoDB GridFS (database `object_remover`); the IDs returned by the upload endpoints are fetched from GridFS before the data is sent to Gemini.
+- Processing is delegated to the Google Gemini/Imagen edit API; only lightweight CPU work (mask preparation, file I/O) happens on this server.
 
 ## Quick Start
 
@@ -39,7 +41,7 @@ Remove objects using the uploaded image and mask:
 ```bash
 curl -H "Authorization: Bearer <API_TOKEN>" \
   -H "Content-Type: application/json" \
-  -d '{"image_id":"9cf61445-f83b-4c97-9272-c81647f90d68","mask_id":"d044a390-dde2-408a-b7cf-d508385e56ed"}' \
+  -d '{"image_id":"9cf61445-f83b-4c97-9272-c81647f90d68","mask_id":"d044a390-dde2-408a-b7cf-d508385e56ed","prompt":"Describe what should be removed"}' \
   https://logicgoinfotechspaces-object-remover.hf.space/inpaint
 ```
 **Response:** `{"result":"output_b09568698bbd4aa591b1598c01f2f745.png"}`
@@ -57,7 +59,7 @@ Use `/inpaint-url` to get a shareable URL:
 ```bash
 curl -H "Authorization: Bearer <API_TOKEN>" \
   -H "Content-Type: application/json" \
-  -d '{"image_id":"<image_id>","mask_id":"<mask_id>"}' \
+  -d '{"image_id":"<image_id>","mask_id":"<mask_id>","prompt":"Describe what should be removed"}' \
   https://logicgoinfotechspaces-object-remover.hf.space/inpaint-url
 ```
 **Response:**
@@ -74,6 +76,7 @@ Upload and process in a single request:
 curl -H "Authorization: Bearer <API_TOKEN>" \
   -F image=@image.jpg \
   -F mask=@mask.jpg \
+  -F prompt="Describe what should be removed" \
   https://logicgoinfotechspaces-object-remover.hf.space/inpaint-multipart
 ```
 
@@ -112,7 +115,8 @@ curl -L https://logicgoinfotechspaces-object-remover.hf.space/download/output_xx
 ```json
 {
   "image_id": "9cf61445-f83b-4c97-9272-c81647f90d68",
-  "mask_id": "d044a390-dde2-408a-b7cf-d508385e56ed"
+  "mask_id": "d044a390-dde2-408a-b7cf-d508385e56ed",
+  "prompt": "Describe what should be removed"
 }
 ```
 
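Aside from curl, the requests in the guide above can be scripted. The sketch below assembles a `/inpaint` call exactly as documented (bearer token optional, `prompt` optional); `build_inpaint_request` is an illustrative helper, not part of the repository, and the actual HTTP send is left as a comment so the snippet stays offline.

```python
import json
from typing import Dict, Optional, Tuple

BASE_URL = "https://logicgoinfotechspaces-object-remover.hf.space"


def build_inpaint_request(
    image_id: str,
    mask_id: str,
    prompt: Optional[str] = None,
    token: Optional[str] = None,
) -> Tuple[str, Dict[str, str], str]:
    """Return (url, headers, json_body) for a POST /inpaint call."""
    payload = {"image_id": image_id, "mask_id": mask_id}
    if prompt:
        payload["prompt"] = prompt  # prompt is optional; omit the key when absent
    headers = {"Content-Type": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"  # bearer auth is optional
    return f"{BASE_URL}/inpaint", headers, json.dumps(payload)


url, headers, body = build_inpaint_request(
    "9cf61445-f83b-4c97-9272-c81647f90d68",
    "d044a390-dde2-408a-b7cf-d508385e56ed",
    prompt="Describe what should be removed",
)
# Send with any HTTP client, e.g. requests.post(url, headers=headers, data=body)
```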
HTTP_API_Documentation.txt CHANGED
@@ -7,6 +7,8 @@ Authentication:
 - Set API_TOKEN environment variable on server to enable auth
 - Send header: Authorization: Bearer <API_TOKEN>
 - If API_TOKEN not set, all endpoints are publicly accessible
+- Inpainting work is delegated to the Google Gemini/Imagen edit API; no GPU is needed on the server.
+- Uploads are stored in MongoDB GridFS (database `object_remover`); the IDs returned by upload endpoints are fetched from GridFS when processing.
 
 Available Endpoints:
 
@@ -34,21 +36,21 @@ Available Endpoints:
 4. POST /inpaint
    - Process inpainting using uploaded image and mask IDs
    - Content-Type: application/json
-   - Body: {"image_id":"<image_id>","mask_id":"<mask_id>"}
+   - Body: {"image_id":"<image_id>","mask_id":"<mask_id>","prompt":"optional text about what to remove"}
    - Returns: {"result":"output_xxx.png"}
    - Simple response with just the filename
 
 5. POST /inpaint-url
    - Same as /inpaint but returns JSON with public download URL
    - Content-Type: application/json
-   - Body: {"image_id":"<image_id>","mask_id":"<mask_id>"}
+   - Body: {"image_id":"<image_id>","mask_id":"<mask_id>","prompt":"optional text about what to remove"}
    - Returns: {"result":"output_xxx.png","url":"https://.../download/output_xxx.png"}
    - Use this endpoint if you need a shareable URL
 
 6. POST /inpaint-multipart
    - Process inpainting with direct file upload (no separate upload steps)
    - Content-Type: multipart/form-data
-   - Form fields: image (file), mask (file)
+   - Form fields: image (file), mask (file), prompt (optional text)
    - Returns: {"result":"output_xxx.png","url":"https://.../download/output_xxx.png"}
 
 7. GET /download/(unknown)
@@ -84,14 +86,14 @@ curl -H "Authorization: Bearer <API_TOKEN>" \
 4. Inpaint (returns simple JSON):
 curl -H "Authorization: Bearer <API_TOKEN>" \
   -H "Content-Type: application/json" \
-  -d '{"image_id":"9cf61445-f83b-4c97-9272-c81647f90d68","mask_id":"d044a390-dde2-408a-b7cf-d508385e56ed"}' \
+  -d '{"image_id":"9cf61445-f83b-4c97-9272-c81647f90d68","mask_id":"d044a390-dde2-408a-b7cf-d508385e56ed","prompt":"Remove the car and repair the road"}' \
   https://logicgoinfotechspaces-object-remover.hf.space/inpaint
 # Response: {"result":"output_b09568698bbd4aa591b1598c01f2f745.png"}
 
 5. Inpaint-URL (returns JSON with public URL):
 curl -H "Authorization: Bearer <API_TOKEN>" \
   -H "Content-Type: application/json" \
-  -d '{"image_id":"9cf61445-f83b-4c97-9272-c81647f90d68","mask_id":"d044a390-dde2-408a-b7cf-d508385e56ed"}' \
+  -d '{"image_id":"9cf61445-f83b-4c97-9272-c81647f90d68","mask_id":"d044a390-dde2-408a-b7cf-d508385e56ed","prompt":"Remove the car and repair the road"}' \
   https://logicgoinfotechspaces-object-remover.hf.space/inpaint-url
 # Response: {"result":"output_b09568698bbd4aa591b1598c01f2f745.png","url":"https://logicgoinfotechspaces-object-remover.hf.space/download/output_b09568698bbd4aa591b1598c01f2f745.png"}
 
@@ -140,7 +142,8 @@ POSTMAN EXAMPLES:
 Body: raw JSON
 {
   "image_id": "9cf61445-f83b-4c97-9272-c81647f90d68",
-  "mask_id": "d044a390-dde2-408a-b7cf-d508385e56ed"
+  "mask_id": "d044a390-dde2-408a-b7cf-d508385e56ed",
+  "prompt": "Remove the car and restore the street"
 }
 
 5. Inpaint Multipart (one-step):
@@ -150,6 +153,7 @@ POSTMAN EXAMPLES:
 Body: form-data
 Key: image, Type: File, Value: select your image file
 Key: mask, Type: File, Value: select your mask file
+Key: prompt, Type: Text, Value: describe what to remove (optional)
 
 IMPORTANT NOTES:
 
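The auth rule documented above (token optional; fully public when `API_TOKEN` is unset) fits in a few lines. This is a hedged sketch only — the server's real dependency in `api/main.py` is named `bearer_auth` and its body is not shown in this diff, so `check_bearer` is a hypothetical standalone version of the same rule.

```python
from typing import Optional


def check_bearer(auth_header: Optional[str], api_token: Optional[str]) -> bool:
    """Return True when a request may proceed under the documented auth rule."""
    if not api_token:
        # No API_TOKEN configured on the server -> all endpoints are public
        return True
    if not auth_header or not auth_header.startswith("Bearer "):
        return False
    # Token configured: the header must carry exactly that token
    return auth_header[len("Bearer "):] == api_token
```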
README.md CHANGED
Binary files a/README.md and b/README.md differ
 
api/main.py CHANGED
@@ -4,6 +4,7 @@ import uuid
 import shutil
 import re
 from datetime import datetime, timedelta, date
+from io import BytesIO
 from typing import Dict, List, Optional
 
 import numpy as np
@@ -19,14 +20,24 @@ from fastapi import (
 )
 from fastapi.responses import FileResponse, JSONResponse
 from pydantic import BaseModel
-from PIL import Image
+from PIL import Image, UnidentifiedImageError
 import cv2
 import logging
+from gridfs import GridFS
+from gridfs.errors import NoFile
 
 from bson import ObjectId
 from pymongo import MongoClient
 import time
 
+# Load environment variables from .env if present
+try:
+    from dotenv import load_dotenv
+
+    load_dotenv()
+except Exception:
+    pass
+
 logging.basicConfig(level=logging.INFO)
 log = logging.getLogger("api")
 
@@ -53,10 +64,15 @@ app = FastAPI(title="Photo Object Removal API", version="1.0.0")
 file_store: Dict[str, Dict[str, str]] = {}
 logs: List[Dict[str, str]] = []
 
-MONGO_URI = "mongodb+srv://harilogicgo_db_user:pdnh6UCMsWvuTCoi@kiddoimages.k2a4nuv.mongodb.net/?appName=KiddoImages"
+MONGO_URI = (
+    os.environ.get("MONGO_URI")
+    or os.environ.get("MONGODB_URI")
+    or "mongodb+srv://harilogicgo_db_user:pdnh6UCMsWvuTCoi@kiddoimages.k2a4nuv.mongodb.net/?appName=KiddoImages"
+)
 mongo_client = MongoClient(MONGO_URI)
 mongo_db = mongo_client["object_remover"]
 mongo_logs = mongo_db["api_logs"]
+grid_fs = GridFS(mongo_db)
 
 ADMIN_MONGO_URI = os.environ.get("MONGODB_ADMIN")
 DEFAULT_CATEGORY_ID = "69368f722e46bd68ae188984"
@@ -71,11 +87,15 @@ def _init_admin_mongo() -> None:
     try:
         admin_client = MongoClient(ADMIN_MONGO_URI)
         # get_default_database() extracts database from connection string (e.g., /adminPanel)
-        admin_db = admin_client.get_default_database()
+        try:
+            admin_db = admin_client.get_default_database()
+        except Exception as db_err:
+            admin_db = None
+            log.warning("Admin Mongo URI has no default DB; error=%s", db_err)
         if admin_db is None:
-            # Fallback if no database in URI
-            admin_db = admin_client["admin"]
-            log.warning("No database in connection string, defaulting to 'admin'")
+            # Fallback to provided default for this app
+            admin_db = admin_client["object_remover"]
+            log.warning("No database in connection string, defaulting to 'object_remover'")
 
         admin_media_clicks = admin_db["media_clicks"]
         log.info(
@@ -112,6 +132,50 @@ def _admin_logging_status() -> Dict[str, object]:
     }
 
 
+def _save_upload_to_gridfs(upload: UploadFile, file_type: str) -> str:
+    """Store an uploaded file into GridFS and return its ObjectId string."""
+    data = upload.file.read()
+    if not data:
+        raise HTTPException(status_code=400, detail=f"{file_type} file is empty")
+    oid = grid_fs.put(
+        data,
+        filename=upload.filename or f"{file_type}.bin",
+        contentType=upload.content_type,
+        metadata={"type": file_type},
+    )
+    return str(oid)
+
+
+def _read_gridfs_bytes(file_id: str, expected_type: str) -> bytes:
+    """Fetch raw bytes from GridFS and validate the stored type metadata."""
+    try:
+        oid = ObjectId(file_id)
+    except Exception:
+        raise HTTPException(status_code=404, detail=f"{expected_type}_id invalid")
+
+    try:
+        grid_out = grid_fs.get(oid)
+    except NoFile:
+        raise HTTPException(status_code=404, detail=f"{expected_type}_id not found")
+
+    meta = grid_out.metadata or {}
+    stored_type = meta.get("type")
+    if stored_type and stored_type != expected_type:
+        raise HTTPException(status_code=404, detail=f"{expected_type}_id not found")
+
+    return grid_out.read()
+
+
+def _load_rgba_image_from_gridfs(file_id: str, expected_type: str) -> Image.Image:
+    """Load an image from GridFS and convert to RGBA."""
+    data = _read_gridfs_bytes(file_id, expected_type)
+    try:
+        img = Image.open(BytesIO(data))
+    except UnidentifiedImageError:
+        raise HTTPException(status_code=422, detail=f"{expected_type} is not a valid image")
+    return img.convert("RGBA")
+
+
 def _build_ai_edit_daily_count(
     existing: Optional[List[Dict[str, object]]],
     today: date,
@@ -223,6 +287,7 @@ class InpaintRequest(BaseModel):
     mask_id: str
     invert_mask: bool = True  # True => selected/painted area is removed
     passthrough: bool = False  # If True, return the original image unchanged
+    prompt: Optional[str] = None  # Optional: describe what to remove
     user_id: Optional[str] = None
     category_id: Optional[str] = None
 
@@ -382,47 +447,18 @@ def logging_status(_: None = Depends(bearer_auth)) -> Dict[str, object]:
 
 @app.post("/upload-image")
 def upload_image(image: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
-    ext = os.path.splitext(image.filename)[1] or ".png"
-    file_id = str(uuid.uuid4())
-    stored_name = f"{file_id}{ext}"
-    stored_path = os.path.join(UPLOAD_DIR, stored_name)
-    with open(stored_path, "wb") as f:
-        shutil.copyfileobj(image.file, f)
-    file_store[file_id] = {
-        "type": "image",
-        "filename": image.filename,
-        "stored_name": stored_name,
-        "path": stored_path,
-        "timestamp": datetime.utcnow().isoformat(),
-    }
+    file_id = _save_upload_to_gridfs(image, "image")
     logs.append({"id": file_id, "filename": image.filename, "type": "image", "timestamp": datetime.utcnow().isoformat()})
     return {"id": file_id, "filename": image.filename}
 
 
 @app.post("/upload-mask")
 def upload_mask(mask: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
-    ext = os.path.splitext(mask.filename)[1] or ".png"
-    file_id = str(uuid.uuid4())
-    stored_name = f"{file_id}{ext}"
-    stored_path = os.path.join(UPLOAD_DIR, stored_name)
-    with open(stored_path, "wb") as f:
-        shutil.copyfileobj(mask.file, f)
-    file_store[file_id] = {
-        "type": "mask",
-        "filename": mask.filename,
-        "stored_name": stored_name,
-        "path": stored_path,
-        "timestamp": datetime.utcnow().isoformat(),
-    }
+    file_id = _save_upload_to_gridfs(mask, "mask")
     logs.append({"id": file_id, "filename": mask.filename, "type": "mask", "timestamp": datetime.utcnow().isoformat()})
     return {"id": file_id, "filename": mask.filename}
 
 
-def _load_rgba_image(path: str) -> Image.Image:
-    img = Image.open(path)
-    return img.convert("RGBA")
-
-
 def _compress_image(image_path: str, output_path: str, quality: int = 85) -> None:
     """
     Compress an image to reduce file size.
@@ -503,14 +539,8 @@ def inpaint(req: InpaintRequest, request: Request, _: None = Depends(bearer_auth
     compressed_url = None
 
     try:
-        if req.image_id not in file_store or file_store[req.image_id]["type"] != "image":
-            raise HTTPException(status_code=404, detail="image_id not found")
-
-        if req.mask_id not in file_store or file_store[req.mask_id]["type"] != "mask":
-            raise HTTPException(status_code=404, detail="mask_id not found")
-
-        img_rgba = _load_rgba_image(file_store[req.image_id]["path"])
-        mask_img = Image.open(file_store[req.mask_id]["path"])
+        img_rgba = _load_rgba_image_from_gridfs(req.image_id, "image")
+        mask_img = _load_rgba_image_from_gridfs(req.mask_id, "mask")
         mask_rgba = _load_rgba_mask_from_image(mask_img)
 
         if req.passthrough:
@@ -519,7 +549,8 @@
             result = process_inpaint(
                 np.array(img_rgba),
                 mask_rgba,
-                invert_mask=req.invert_mask
+                invert_mask=req.invert_mask,
+                prompt=req.prompt,
             )
 
         output_name = f"output_{uuid.uuid4().hex}.png"
@@ -608,19 +639,19 @@ def inpaint_url(req: InpaintRequest, request: Request, _: None = Depends(bearer_
     result_name = None
 
     try:
-        if req.image_id not in file_store or file_store[req.image_id]["type"] != "image":
-            raise HTTPException(status_code=404, detail="image_id not found")
-        if req.mask_id not in file_store or file_store[req.mask_id]["type"] != "mask":
-            raise HTTPException(status_code=404, detail="mask_id not found")
-
-        img_rgba = _load_rgba_image(file_store[req.image_id]["path"])
-        mask_img = Image.open(file_store[req.mask_id]["path"])  # may be RGB/gray/RGBA
+        img_rgba = _load_rgba_image_from_gridfs(req.image_id, "image")
+        mask_img = _load_rgba_image_from_gridfs(req.mask_id, "mask")  # may be RGB/gray/RGBA
         mask_rgba = _load_rgba_mask_from_image(mask_img)
 
         if req.passthrough:
             result = np.array(img_rgba.convert("RGB"))
         else:
-            result = process_inpaint(np.array(img_rgba), mask_rgba, invert_mask=req.invert_mask)
+            result = process_inpaint(
+                np.array(img_rgba),
+                mask_rgba,
+                invert_mask=req.invert_mask,
+                prompt=req.prompt,
+            )
         result_name = f"output_{uuid.uuid4().hex}.png"
         result_path = os.path.join(OUTPUT_DIR, result_name)
         Image.fromarray(result).save(result_path, "PNG", optimize=False, compress_level=1)
@@ -662,6 +693,7 @@ def inpaint_multipart(
     invert_mask: bool = True,
     mask_is_painted: bool = False,  # if True, mask file is the painted-on image (e.g., black strokes on original)
     passthrough: bool = False,
+    prompt: Optional[str] = Form(None),
     user_id: Optional[str] = Form(None),
     category_id: Optional[str] = Form(None),
     _: None = Depends(bearer_auth),
@@ -774,7 +806,12 @@
         actual_invert = invert_mask  # Use default True for painted masks
     log.info("Using invert_mask=%s (mask_is_painted=%s)", actual_invert, mask_is_painted)
 
-    result = process_inpaint(np.array(img), mask_rgba, invert_mask=actual_invert)
+    result = process_inpaint(
+        np.array(img),
+        mask_rgba,
+        invert_mask=actual_invert,
+        prompt=prompt,
+    )
     result_name = f"output_{uuid.uuid4().hex}.png"
     result_path = os.path.join(OUTPUT_DIR, result_name)
     Image.fromarray(result).save(result_path, "PNG", optimize=False, compress_level=1)
@@ -1930,4 +1967,4 @@ def get_logs(_: None = Depends(bearer_auth)) -> JSONResponse:
 
 # @app.get("/logs")
 # def get_logs(_: None = Depends(bearer_auth)) -> JSONResponse:
-#     return JSONResponse(content=logs)
+#     return JSONResponse(content=logs)
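The `invert_mask` flag and `_load_rgba_mask_from_image` recur throughout this diff, but their internals are not shown here. As a rough sketch of one plausible reading of the flag's comment (`True => selected/painted area is removed`), a binary removal map could be derived as below; `prepare_removal_mask` and its threshold are illustrative assumptions, not the repository's actual mask code.

```python
import numpy as np


def prepare_removal_mask(mask: np.ndarray, invert_mask: bool = True,
                         threshold: int = 10) -> np.ndarray:
    """Reduce a mask array to a binary removal map (255 = remove, 0 = keep).

    Uses the alpha channel of an RGBA mask when present, otherwise mean
    intensity. invert_mask=True treats the painted/selected pixels as the
    region to remove, matching the flag's comment in the request model.
    """
    if mask.ndim == 3 and mask.shape[2] == 4:
        gray = mask[:, :, 3]            # alpha channel marks the painted area
    elif mask.ndim == 3:
        gray = mask.mean(axis=2)        # fall back to luminance for RGB
    else:
        gray = mask                     # already single-channel
    selected = gray > threshold                      # painted/selected pixels
    remove = selected if invert_mask else ~selected  # flip which side is removed
    return remove.astype(np.uint8) * 255
```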
app.py CHANGED
@@ -2,10 +2,8 @@
 # Model based on: https://github.com/saic-mdal/lama
 
 import numpy as np
-import pandas as pd
 import streamlit as st
 import os
-from datetime import datetime
 from PIL import Image
 from streamlit_drawable_canvas import st_canvas
 from io import BytesIO
assets/big-lama.pt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:344c77bbcb158f17dd143070d1e789f38a66c04202311ae3a258ef66667a9ea9
-size 205669692
requirements.txt CHANGED
@@ -1,29 +1,11 @@
-torch
-torchvision
-numpy
-opencv-python-headless
-matplotlib
-streamlit==1.24.1
-gradio==5.47.0
-streamlit-drawable-canvas==0.9.0
-pyyaml
-tqdm
-easydict==1.9.0
-scikit-image
-scipy>=1.14.1
-tensorflow
-joblib
-pandas
-albumentations==0.5.2
-hydra-core==1.1.0
-pytorch-lightning==1.2.9
-tabulate
-kornia==0.5.0
-webdataset
-packaging
-wldhx.yadisk-direct
-altair<5
 fastapi
 uvicorn[standard]
 python-multipart
-pymongo
+google-genai>=1.38.0
+google-generativeai
+python-dotenv
+numpy
+opencv-python-headless
+Pillow
+pymongo
+streamlit==1.24.1
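The new `python-dotenv` dependency pairs with the guarded import added to `api/main.py`: the package (or the `.env` file) may be absent, so the app must still start without it. A minimal standalone version of that pattern, with a local placeholder default instead of the app's real connection string (`resolve_mongo_uri` is illustrative, not from the repo):

```python
import os

# Optional .env loading: guard the import so a missing package is harmless,
# mirroring the try/except block added in api/main.py.
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception:
    pass


def resolve_mongo_uri(default: str = "mongodb://localhost:27017") -> str:
    """First configured environment variable wins; the default argument here
    is a placeholder for local development, not the app's real URI."""
    return (
        os.environ.get("MONGO_URI")
        or os.environ.get("MONGODB_URI")
        or default
    )
```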
src/core.py CHANGED
@@ -1,556 +1,216 @@
-import base64
-import json
 import os
-import re
-import time
-import uuid
 from io import BytesIO
-from pathlib import Path
-import cv2
-
-# For inpainting
-
-import numpy as np
-import pandas as pd
-import streamlit as st
-from PIL import Image
-from streamlit_drawable_canvas import st_canvas
-
-
-import argparse
-import io
-import multiprocessing
-from typing import Union
-
-import torch
 
 try:
-    torch._C._jit_override_can_fuse_on_cpu(False)
-    torch._C._jit_override_can_fuse_on_gpu(False)
-    torch._C._jit_set_texpr_fuser_enabled(False)
-    torch._C._jit_set_nvfuser_enabled(False)
-except:
     pass
 
-from src.helper import (
-    download_model,
-    load_img,
-    norm_img,
-    numpy_to_bytes,
-    pad_img_to_modulo,
-    resize_max_size,
 )
 
-NUM_THREADS = str(multiprocessing.cpu_count())
-
-os.environ["OMP_NUM_THREADS"] = NUM_THREADS
-os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
-os.environ["MKL_NUM_THREADS"] = NUM_THREADS
-os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
-os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
-if os.environ.get("CACHE_DIR"):
-    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
 
-#BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
 
-# For Seam-carving
 
-from scipy import ndimage as ndi
-
-SEAM_COLOR = np.array([255, 200, 200])  # seam visualization color (BGR)
-SHOULD_DOWNSIZE = True                  # if True, downsize image for faster carving
-DOWNSIZE_WIDTH = 500                    # resized image width if SHOULD_DOWNSIZE is True
-ENERGY_MASK_CONST = 100000.0            # large energy value for protective masking
-MASK_THRESHOLD = 10                     # minimum pixel intensity for binary mask
-USE_FORWARD_ENERGY = True               # if True, use forward energy algorithm
-
-device_str = os.environ.get("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
-device = torch.device(device_str)
-model_path = "./assets/big-lama.pt"
-model = torch.jit.load(model_path, map_location=device)
-model = model.to(device)
-model.eval()
-
-
-########################################
-# UTILITY CODE
-########################################
-
-
-def visualize(im, boolmask=None, rotate=False):
-    vis = im.astype(np.uint8)
-    if boolmask is not None:
-        vis[np.where(boolmask == False)] = SEAM_COLOR
-    if rotate:
-        vis = rotate_image(vis, False)
-    cv2.imshow("visualization", vis)
-    cv2.waitKey(1)
-    return vis
-
-def resize(image, width):
-    dim = None
-    h, w = image.shape[:2]
-    dim = (width, int(h * width / float(w)))
-    image = image.astype('float32')
-    return cv2.resize(image, dim)
-
-def rotate_image(image, clockwise):
-    k = 1 if clockwise else 3
-    return np.rot90(image, k)
-
-
-########################################
-# ENERGY FUNCTIONS
-########################################
-
-def backward_energy(im):
     """
-    Simple gradient magnitude energy map.
     """
-    xgrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=1, mode='wrap')
-    ygrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=0, mode='wrap')
-
-    grad_mag = np.sqrt(np.sum(xgrad**2, axis=2) + np.sum(ygrad**2, axis=2))
-
-    # vis = visualize(grad_mag)
-    # cv2.imwrite("backward_energy_demo.jpg", vis)
-
-    return grad_mag
 
-def forward_energy(im):
-    """
-    Forward energy algorithm as described in "Improved Seam Carving for Video Retargeting"
-    by Rubinstein, Shamir, Avidan.
-    Vectorized code adapted from
-    https://github.com/axu2/improved-seam-carving.
-    """
-    h, w = im.shape[:2]
-    im = cv2.cvtColor(im.astype(np.uint8), cv2.COLOR_BGR2GRAY).astype(np.float64)
-
-    energy = np.zeros((h, w))
-    m = np.zeros((h, w))
-
-    U = np.roll(im, 1, axis=0)
-    L = np.roll(im, 1, axis=1)
-    R = np.roll(im, -1, axis=1)
-
-    cU = np.abs(R - L)
-    cL = np.abs(U - L) + cU
-    cR = np.abs(U - R) + cU
-
-    for i in range(1, h):
-        mU = m[i-1]
-        mL = np.roll(mU, 1)
-        mR = np.roll(mU, -1)
-
-        mULR = np.array([mU, mL, mR])
-        cULR = np.array([cU[i], cL[i], cR[i]])
-        mULR += cULR
-
-        argmins = np.argmin(mULR, axis=0)
-        m[i] = np.choose(argmins, mULR)
-        energy[i] = np.choose(argmins, cULR)
-
-    # vis = visualize(energy)
-    # cv2.imwrite("forward_energy_demo.jpg", vis)
-
-    return energy
-
-########################################
-# SEAM HELPER FUNCTIONS
-########################################
-
-def add_seam(im, seam_idx):
-    """
-    Add a vertical seam to a 3-channel color image at the indices provided
-    by averaging the pixels values to the left and right of the seam.
-    Code adapted from https://github.com/vivianhylee/seam-carving.
-    """
-    h, w = im.shape[:2]
-    output = np.zeros((h, w + 1, 3))
-    for row in range(h):
-        col = seam_idx[row]
-        for ch in range(3):
-            if col == 0:
-                p = np.mean(im[row, col: col + 2, ch])
-                output[row, col, ch] = im[row, col, ch]
-                output[row, col + 1, ch] = p
-                output[row, col + 1:, ch] = im[row, col:, ch]
-            else:
-                p = np.mean(im[row, col - 1: col + 1, ch])
-                output[row, : col, ch] = im[row, : col, ch]
-                output[row, col, ch] = p
-                output[row, col + 1:, ch] = im[row, col:, ch]
-
-    return output
-
-def add_seam_grayscale(im, seam_idx):
-    """
-    Add a vertical seam to a grayscale image at the indices provided
-    by averaging the pixels values to the left and right of the seam.
-    """
-    h, w = im.shape[:2]
-    output = np.zeros((h, w + 1))
-    for row in range(h):
-        col = seam_idx[row]
-        if col == 0:
-            p = np.mean(im[row, col: col + 2])
-            output[row, col] = im[row, col]
-            output[row, col + 1] = p
-            output[row, col + 1:] = im[row, col:]
-        else:
-            p = np.mean(im[row, col - 1: col + 1])
-            output[row, : col] = im[row, : col]
-            output[row, col] = p
-            output[row, col + 1:] = im[row, col:]
-
-    return output
-
-def remove_seam(im, boolmask):
-    h, w = im.shape[:2]
-    boolmask3c = np.stack([boolmask] * 3, axis=2)
-    return im[boolmask3c].reshape((h, w - 1, 3))
-
-def remove_seam_grayscale(im, boolmask):
-    h, w = im.shape[:2]
-    return im[boolmask].reshape((h, w - 1))
-
-def get_minimum_seam(im, mask=None, remove_mask=None):
     """
-    DP algorithm for finding the seam of minimum energy. Code adapted from
-    https://karthikkaranth.me/blog/implementing-seam-carving-with-python/
     """
-    h, w = im.shape[:2]
-    energyfn = forward_energy if USE_FORWARD_ENERGY else backward_energy
-    M = energyfn(im)
-
-    if mask is not None:
-        M[np.where(mask > MASK_THRESHOLD)] = ENERGY_MASK_CONST
-
-    # give removal mask priority over protective mask by using larger negative value
-    if remove_mask is not None:
-        M[np.where(remove_mask > MASK_THRESHOLD)] = -ENERGY_MASK_CONST * 100
-
-    seam_idx, boolmask = compute_shortest_path(M, im, h, w)
-
-    return np.array(seam_idx), boolmask
-
-def compute_shortest_path(M, im, h, w):
-    backtrack = np.zeros_like(M, dtype=np.int_)
-
-    # populate DP matrix
-    for i in range(1, h):
-        for j in range(0, w):
-            if j == 0:
-                idx = np.argmin(M[i - 1, j:j + 2])
-                backtrack[i, j] = idx + j
-                min_energy = M[i-1, idx + j]
-            else:
-                idx = np.argmin(M[i - 1, j - 1:j + 2])
-                backtrack[i, j] = idx + j - 1
-                min_energy = M[i - 1, idx + j - 1]
-
-            M[i, j] += min_energy
-
-    # backtrack to find path
-    seam_idx = []
-    boolmask = np.ones((h, w), dtype=np.bool_)
-    j = np.argmin(M[-1])
-    for i in range(h-1, -1, -1):
261
- boolmask[i, j] = False
262
- seam_idx.append(j)
263
- j = backtrack[i, j]
264
-
265
- seam_idx.reverse()
266
- return seam_idx, boolmask
267
-
268
- ########################################
269
- # MAIN ALGORITHM
270
- ########################################
271
-
272
- def seams_removal(im, num_remove, mask=None, vis=False, rot=False):
273
- for _ in range(num_remove):
274
- seam_idx, boolmask = get_minimum_seam(im, mask)
275
- if vis:
276
- visualize(im, boolmask, rotate=rot)
277
- im = remove_seam(im, boolmask)
278
- if mask is not None:
279
- mask = remove_seam_grayscale(mask, boolmask)
280
- return im, mask
281
-
282
-
283
- def seams_insertion(im, num_add, mask=None, vis=False, rot=False):
284
- seams_record = []
285
- temp_im = im.copy()
286
- temp_mask = mask.copy() if mask is not None else None
287
-
288
- for _ in range(num_add):
289
- seam_idx, boolmask = get_minimum_seam(temp_im, temp_mask)
290
- if vis:
291
- visualize(temp_im, boolmask, rotate=rot)
292
-
293
- seams_record.append(seam_idx)
294
- temp_im = remove_seam(temp_im, boolmask)
295
- if temp_mask is not None:
296
- temp_mask = remove_seam_grayscale(temp_mask, boolmask)
297
-
298
- seams_record.reverse()
299
-
300
- for _ in range(num_add):
301
- seam = seams_record.pop()
302
- im = add_seam(im, seam)
303
- if vis:
304
- visualize(im, rotate=rot)
305
- if mask is not None:
306
- mask = add_seam_grayscale(mask, seam)
307
-
308
- # update the remaining seam indices
309
- for remaining_seam in seams_record:
310
- remaining_seam[np.where(remaining_seam >= seam)] += 2
311
-
312
- return im, mask
313
-
314
- ########################################
315
- # MAIN DRIVER FUNCTIONS
316
- ########################################
317
-
318
- def seam_carve(im, dy, dx, mask=None, vis=False):
319
- im = im.astype(np.float64)
320
- h, w = im.shape[:2]
321
- assert h + dy > 0 and w + dx > 0 and dy <= h and dx <= w
322
-
323
- if mask is not None:
324
- mask = mask.astype(np.float64)
325
-
326
- output = im
327
-
328
- if dx < 0:
329
- output, mask = seams_removal(output, -dx, mask, vis)
330
-
331
- elif dx > 0:
332
- output, mask = seams_insertion(output, dx, mask, vis)
333
-
334
- if dy < 0:
335
- output = rotate_image(output, True)
336
- if mask is not None:
337
- mask = rotate_image(mask, True)
338
- output, mask = seams_removal(output, -dy, mask, vis, rot=True)
339
- output = rotate_image(output, False)
340
-
341
- elif dy > 0:
342
- output = rotate_image(output, True)
343
- if mask is not None:
344
- mask = rotate_image(mask, True)
345
- output, mask = seams_insertion(output, dy, mask, vis, rot=True)
346
- output = rotate_image(output, False)
347
-
348
- return output
349
-
350
-
351
- def object_removal(im, rmask, mask=None, vis=False, horizontal_removal=False):
352
- im = im.astype(np.float64)
353
- rmask = rmask.astype(np.float64)
354
- if mask is not None:
355
- mask = mask.astype(np.float64)
356
- output = im
357
-
358
- h, w = im.shape[:2]
359
-
360
- if horizontal_removal:
361
- output = rotate_image(output, True)
362
- rmask = rotate_image(rmask, True)
363
- if mask is not None:
364
- mask = rotate_image(mask, True)
365
-
366
- while len(np.where(rmask > MASK_THRESHOLD)[0]) > 0:
367
- seam_idx, boolmask = get_minimum_seam(output, mask, rmask)
368
- if vis:
369
- visualize(output, boolmask, rotate=horizontal_removal)
370
- output = remove_seam(output, boolmask)
371
- rmask = remove_seam_grayscale(rmask, boolmask)
372
- if mask is not None:
373
- mask = remove_seam_grayscale(mask, boolmask)
374
-
375
- num_add = (h if horizontal_removal else w) - output.shape[1]
376
- output, mask = seams_insertion(output, num_add, mask, vis, rot=horizontal_removal)
377
- if horizontal_removal:
378
- output = rotate_image(output, False)
379
-
380
- return output
381
-
382
-
383
-
384
- def s_image(im,mask,vs,hs,mode="resize"):
385
- im = cv2.cvtColor(im, cv2.COLOR_RGBA2RGB)
386
- mask = 255-mask[:,:,3]
387
- h, w = im.shape[:2]
388
- if SHOULD_DOWNSIZE and w > DOWNSIZE_WIDTH:
389
- im = resize(im, width=DOWNSIZE_WIDTH)
390
- if mask is not None:
391
- mask = resize(mask, width=DOWNSIZE_WIDTH)
392
-
393
- # image resize mode
394
- if mode=="resize":
395
- dy = hs#reverse
396
- dx = vs#reverse
397
- assert dy is not None and dx is not None
398
- output = seam_carve(im, dy, dx, mask, False)
399
-
400
-
401
- # object removal mode
402
- elif mode=="remove":
403
- assert mask is not None
404
- output = object_removal(im, mask, None, False, True)
405
-
406
- return output
407
-
408
-
409
- ##### Inpainting helper code
410
-
411
- def run(image, mask):
412
  """
413
- image: [C, H, W]
414
- mask: [1, H, W]
415
- return: BGR IMAGE
416
  """
417
- origin_height, origin_width = image.shape[1:]
418
- image = pad_img_to_modulo(image, mod=8)
419
- mask = pad_img_to_modulo(mask, mod=8)
420
-
421
- mask = (mask > 0) * 1
422
- image = torch.from_numpy(image).unsqueeze(0).to(device)
423
- mask = torch.from_numpy(mask).unsqueeze(0).to(device)
424
-
425
- start = time.time()
426
- with torch.no_grad():
427
- inpainted_image = model(image, mask)
428
-
429
- print(f"process time: {(time.time() - start)*1000}ms")
430
- cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
431
- cur_res = cur_res[0:origin_height, 0:origin_width, :]
432
- cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
433
- cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
434
- return cur_res
435
-
436
-
437
- def get_args_parser():
438
- parser = argparse.ArgumentParser()
439
- parser.add_argument("--port", default=8080, type=int)
440
- parser.add_argument("--device", default="cuda", type=str)
441
- parser.add_argument("--debug", action="store_true")
442
- return parser.parse_args()
443
 
 
 
 
444
 
445
- def process_inpaint(image, mask, invert_mask=True):
446
- """
447
- Process inpainting - handles both alpha-based masks and RGB-based masks.
448
- Preserves original image quality and dimensions.
449
- Reference: https://huggingface.co/spaces/aryadytm/remove-photo-object
450
- """
451
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
452
- original_shape = image.shape # (H, W, C)
453
- interpolation = cv2.INTER_CUBIC
454
-
455
- # Preserve original size - only resize if absolutely necessary for memory/performance
456
- # Keep original quality by preserving dimensions
457
- max_dimension = max(image.shape[:2])
458
- # Don't resize unless image is extremely large (over 3000px) to preserve quality
459
- if max_dimension > 3000:
460
- size_limit = 3000
461
- print(f"Very large image detected ({max_dimension}px), resizing to {size_limit}px for processing")
462
- else:
463
- size_limit = max_dimension # Keep original size to preserve quality
464
- print(f"Preserving original image size: {max_dimension}px (no resize)")
465
-
466
- print(f"Origin image shape: {original_shape}")
467
-
468
- # Resize image only if needed
469
- if size_limit < max_dimension:
470
- image_resized = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
471
- print(f"Resized image shape: {image_resized.shape}")
472
- else:
473
- image_resized = image
474
- print(f"Image not resized: {image_resized.shape}")
475
-
476
- image = norm_img(image_resized)
477
-
478
- # Handle mask: check if we should use alpha channel or RGB channels
479
- alpha_channel = mask[:,:,3]
480
- rgb_channels = mask[:,:,:3]
481
-
482
- # Check if alpha is meaningful (not all 255)
483
- alpha_mean = alpha_channel.mean()
484
-
485
- if alpha_mean < 240:
486
- # Alpha channel is meaningful (has transparent areas)
487
- # Reference model logic: mask = 255-mask[:,:,3]
488
- # alpha=0 (transparent) → 255 (white/remove)
489
- # alpha=255 (opaque) → 0 (black/keep)
490
- mask = 255 - alpha_channel
491
- transparent_count = int((alpha_channel < 128).sum())
492
- print(f"Using alpha channel: {transparent_count} transparent pixels → white (to remove)")
493
- # For alpha-based masks: invert_mask=True means keep current (white=remove is correct)
494
- # invert_mask=False means invert (white becomes black)
495
- if not invert_mask:
496
- mask = 255 - mask
497
- print(f"Applied invert_mask=False: inverted alpha-based mask")
498
- else:
499
- # Alpha is mostly opaque (255), use RGB channels instead
500
- # RGB masks: white (255) = remove, black (0) = keep (standard convention)
501
- gray = cv2.cvtColor(rgb_channels, cv2.COLOR_RGB2GRAY)
502
- mask = (gray > 128).astype(np.uint8) * 255
503
- white_count = int((mask > 128).sum())
504
- print(f"Using RGB channels: {white_count} white pixels (to remove)")
505
- # For RGB-based masks: white=remove is already correct
506
- # invert_mask=False means we want black=remove (invert it)
507
- if not invert_mask:
508
- mask = 255 - mask # invert: white becomes black, black becomes white
509
- print(f"Applied invert_mask=False: inverted RGB mask (now black=remove)")
510
-
511
- # Resize mask to match image dimensions (always force exact match)
512
- target_h, target_w = image_resized.shape[:2]
513
- if mask.shape[:2] != (target_h, target_w):
514
- mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
515
-
516
- # Debug: log final mask statistics
517
- mask_nonzero = int((mask > 128).sum())
518
- mask_total = mask.shape[0] * mask.shape[1]
519
- print(f"Final mask before normalization: {mask_nonzero}/{mask_total} pixels marked for removal ({100*mask_nonzero/mask_total:.2f}%)")
520
-
521
- if mask_nonzero < 10:
522
- print("ERROR: Mask is empty or almost empty! Returning original image.")
523
- # Return original image at original size
524
- original_rgb = (image_resized * 255).astype(np.uint8)
525
- return cv2.resize(cv2.cvtColor(original_rgb, cv2.COLOR_RGB2BGR),
526
- (original_shape[1], original_shape[0]),
527
- interpolation=cv2.INTER_CUBIC)
528
-
529
- # Verify mask is correct before normalization
530
- print(f"Mask verification: {mask_nonzero} pixels will be removed, shape: {mask.shape}")
531
-
532
- mask = norm_img(mask)
533
-
534
- # Verify normalized mask
535
- mask_normalized_ones = int((mask > 0.5).sum())
536
- print(f"After normalization: {mask_normalized_ones} pixels marked for removal (value > 0.5)")
537
-
538
- # Run inpainting
539
- print("Running LaMa model for inpainting...")
540
- res_np_img = run(image, mask)
541
- print(f"Inpainting complete. Output shape: {res_np_img.shape}")
542
-
543
- # Verify output changed
544
- original_for_compare = (image_resized * 255).astype(np.uint8)
545
- original_bgr = cv2.cvtColor(original_for_compare, cv2.COLOR_RGB2BGR)
546
- diff = np.abs(res_np_img.astype(np.float32) - original_bgr.astype(np.float32))
547
- diff_pixels = int((diff.sum(axis=2) > 10).sum()) # Pixels that changed by more than 10 in any channel
548
- print(f"Output verification: {diff_pixels} pixels differ from input (should be > 0 if inpainting worked)")
549
-
550
- # Resize back to original dimensions if we resized (use LANCZOS4 for better quality)
551
- if size_limit < max_dimension:
552
- res_np_img = cv2.resize(res_np_img, (original_shape[1], original_shape[0]),
553
- interpolation=cv2.INTER_LANCZOS4)
554
- print(f"Resized output back to original size: {res_np_img.shape}")
555
-
556
- return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)
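The deleted `compute_shortest_path` above is the standard seam-carving dynamic program. As a standalone sanity sketch (a hypothetical toy, not part of the repo), the same recurrence can be traced on a tiny energy map:

```python
import numpy as np

def minimum_seam(M):
    """Column indices of the vertical seam with minimum cumulative energy
    (same recurrence as the deleted compute_shortest_path)."""
    M = M.astype(np.float64).copy()
    h, w = M.shape
    backtrack = np.zeros((h, w), dtype=np.int_)
    for i in range(1, h):
        for j in range(w):
            if j == 0:
                idx = int(np.argmin(M[i - 1, j:j + 2]))
                backtrack[i, j] = idx + j
            else:
                idx = int(np.argmin(M[i - 1, j - 1:j + 2]))
                backtrack[i, j] = idx + j - 1
            # accumulate the cheapest reachable energy from the row above
            M[i, j] += M[i - 1, backtrack[i, j]]
    # walk back up from the cheapest bottom cell
    seam = []
    j = int(np.argmin(M[-1]))
    for i in range(h - 1, -1, -1):
        seam.append(j)
        j = int(backtrack[i, j])
    seam.reverse()
    return seam

energy = np.array([[1, 2, 3],
                   [4, 1, 5],
                   [6, 7, 1]])
print(minimum_seam(energy))  # -> [0, 1, 2]
```

The diagonal of ones is the cheapest path, so the returned seam follows it; `seams_removal` would then drop those pixels row by row.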
 
+ import logging
  import os
  from io import BytesIO

+ # Load environment variables from .env if present (helps local dev)
  try:
+     from dotenv import load_dotenv
+
+     load_dotenv()
+ except Exception:
      pass

+ import base64
+ import cv2
+ import numpy as np
+ from PIL import Image
+ import google.generativeai as genai
+
+ log = logging.getLogger(__name__)
+
+ # Remote inference configuration (Gemini API key only; no Vertex required)
+ DEFAULT_MODEL_ID = os.environ.get("GEMINI_IMAGE_MODEL", "gemini-2.5-flash-image")
+ DEFAULT_PROMPT = os.environ.get(
+     "GEMINI_IMAGE_PROMPT",
+     (
+         "TASK TYPE: STRICT IMAGE INPAINTING — OBJECT REMOVAL ONLY\n\n"
+         "You are given:\n"
+         "1) An original image\n"
+         "2) A binary mask image\n\n"
+         "MASK RULE (MANDATORY):\n"
+         "• White pixels (#FFFFFF) indicate the exact region to be REMOVED.\n"
+         "• Black pixels (#000000) indicate regions that MUST remain completely unchanged.\n\n"
+         "PRIMARY OBJECTIVE:\n"
+         "Completely delete everything inside the white masked area.\n"
+         "The object in the white region must be fully removed with no visible remnants,\n"
+         "no partial shapes, no outlines, no shadows, and no color traces.\n\n"
+         "INPAINTING INSTRUCTIONS:\n"
+         "Ignore the content of the white masked area entirely.\n"
+         "Reconstruct that region using ONLY surrounding background information.\n"
+         "Extend nearby background textures, patterns, and structures naturally.\n"
+         "Match lighting direction, brightness, contrast, color temperature, and noise.\n"
+         "Continue edges, lines, and surfaces realistically across the removed area.\n"
+         "Blend boundaries smoothly so the edit is visually undetectable.\n\n"
+         "STRICT CONSTRAINTS:\n"
+         "• Do NOT generate or keep any part of the removed object.\n"
+         "• Do NOT invent new objects or details.\n"
+         "• Do NOT repaint, modify, blur, or enhance any black (unmasked) area.\n"
+         "• Do NOT change the original image composition.\n"
+         "• Do NOT change camera angle, perspective, or scale.\n\n"
+         "QUALITY REQUIREMENTS:\n"
+         "• No ghosting or transparent object remains.\n"
+         "• No edge halos or smearing.\n"
+         "• No repeated textures or patchy fills.\n"
+         "• Result must look like the object never existed.\n\n"
+         "FAILURE CONDITIONS (MUST BE AVOIDED):\n"
+         "If any object fragment, outline, shadow, or color from the removed object\n"
+         "is still visible, the result is incorrect and must be re-generated."
+     ),
  )
+ _GENAI_MODEL: genai.GenerativeModel | None = None
+
+
+ def _resize_mask(mask: np.ndarray, target_hw: tuple[int, int]) -> np.ndarray:
+     """Resize mask to match the target height/width."""
+     target_h, target_w = target_hw
+     if mask.shape[:2] == (target_h, target_w):
+         return mask
+     return cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
+
+
+ def _binary_mask_from_rgba(mask: np.ndarray, invert_mask: bool) -> np.ndarray:
      """
+     Normalize incoming RGBA masks to a 0/255 binary mask.
+     - Transparent alpha (0) is treated as "remove"
+     - White/bright RGB is treated as "remove" when alpha is mostly opaque
      """
+     if mask.shape[2] == 3:
+         alpha_channel = np.ones(mask.shape[:2], dtype=np.uint8) * 255
+         rgb_channels = mask
+     else:
+         alpha_channel = mask[:, :, 3]
+         rgb_channels = mask[:, :, :3]
+
+     # If alpha carries information, prefer it
+     if alpha_channel.mean() < 240:
+         mask_bw = np.where(alpha_channel < 128, 255, 0).astype(np.uint8)
+     else:
+         gray = cv2.cvtColor(rgb_channels, cv2.COLOR_RGB2GRAY)
+         mask_bw = np.where(gray > 128, 255, 0).astype(np.uint8)
+
+     if not invert_mask:
+         mask_bw = 255 - mask_bw
+
+     return mask_bw
+
+
+ def _pil_to_png_bytes(img: Image.Image) -> bytes:
+     """Encode a PIL image to PNG bytes for Gemini edit endpoints."""
+     buf = BytesIO()
+     img.save(buf, format="PNG")
+     buf.seek(0)
+     return buf.getvalue()
+
+
+ def _get_gemini_model() -> genai.GenerativeModel:
+     global _GENAI_MODEL
+     if _GENAI_MODEL is None:
+         api_key = (
+             os.environ.get("GEMINI_API_KEY")
+             or os.environ.get("GOOGLE_API_KEY")
+             or os.environ.get("GOOGLE_GENAI_API_KEY")
+         )
+         if not api_key:
+             raise RuntimeError("Set Gemini API key via GEMINI_API_KEY / GOOGLE_API_KEY / GOOGLE_GENAI_API_KEY")
+         genai.configure(api_key=api_key)
+         model_id = os.environ.get("GEMINI_IMAGE_MODEL", DEFAULT_MODEL_ID)
+         _GENAI_MODEL = genai.GenerativeModel(model_id)
+     return _GENAI_MODEL
+
+
+ def _call_gemini_edit(
+     image_rgb: np.ndarray,
+     mask_bw: np.ndarray,
+     prompt: str | None,
+     target_size: tuple[int, int],
+ ) -> Image.Image:
      """
+     Send source image + binary mask to Gemini via API-key-only generate_content.
+     We include both the base image and the mask as separate parts and instruct the model to remove masked regions.
      """
+     model = _get_gemini_model()
+
+     base_image = Image.fromarray(image_rgb).convert("RGB")
+     mask_image = Image.fromarray(mask_bw).convert("L")
+
+     # Build a guidance image where the removal region is painted white for clarity
+     guidance_rgb = image_rgb.copy()
+     guidance_rgb[mask_bw > 0] = 255
+     guidance_image = Image.fromarray(guidance_rgb).convert("RGB")
+
+     base_bytes = _pil_to_png_bytes(base_image)
+     mask_bytes = _pil_to_png_bytes(mask_image)
+     guidance_bytes = _pil_to_png_bytes(guidance_image)
+
+     # Enrich prompt to explicitly describe the two images being sent
+     effective_prompt = (
+         (prompt or DEFAULT_PROMPT).strip()
+         + "\nIMAGE ORDER:\n"
+         + "Image A: Original photo with the removal region painted white.\n"
+         + "Image B: Binary mask (white=remove, black=keep). Use this mask to decide what to remove.\n"
+     )
+     log.info(
+         "Calling Gemini generate_content model=%s (mask-guided remove) mask_pixels=%d",
+         model.model_name,
+         int((mask_bw > 0).sum()),
+     )
+
+     # Build content parts: prompt + guidance image + mask image (explicit order)
+     content = [
+         effective_prompt,
+         {"mime_type": "image/png", "data": guidance_bytes},
+         {"mime_type": "image/png", "data": mask_bytes},
+     ]
+
+     response = model.generate_content(content, stream=False)
+
+     output_img: Image.Image | None = None
+
+     # Extract first image from response parts
+     try:
+         for candidate in getattr(response, "candidates", []):
+             parts = getattr(candidate, "content", None)
+             if not parts or not getattr(parts, "parts", None):
+                 continue
+             for part in parts.parts:
+                 inline = getattr(part, "inline_data", None)
+                 if inline and inline.data:
+                     data = inline.data
+                     if isinstance(data, str):
+                         data = base64.b64decode(data)
+                     output_img = Image.open(BytesIO(data)).convert("RGB")
+                     break
+             if output_img:
+                 break
+     except Exception as err:
+         log.warning("Failed to parse Gemini response image: %s", err)
+
+     if output_img is None:
+         raise RuntimeError("Gemini generate_content returned no image")
+
+     # Ensure output matches original dimensions if Gemini rescaled
+     if output_img.size != target_size:
+         output_img = output_img.resize(target_size, Image.Resampling.LANCZOS)
+
+     return output_img
+
+
+ def process_inpaint(
+     image: np.ndarray,
+     mask: np.ndarray,
+     invert_mask: bool = True,
+     prompt: str | None = None,
+ ) -> np.ndarray:
      """
+     Forward inpainting to Gemini edit API using source image + mask.
      """
+     image_rgba = Image.fromarray(image).convert("RGBA")
+     image_rgb = np.array(image_rgba.convert("RGB"))
+
+     mask_rgba = np.array(Image.fromarray(mask).convert("RGBA"))
+     mask_bw = _binary_mask_from_rgba(mask_rgba, invert_mask)
+     mask_bw = _resize_mask(mask_bw, image_rgb.shape[:2])
+
+     target_size = (image_rgb.shape[1], image_rgb.shape[0])  # (width, height)
+     edited_image = _call_gemini_edit(image_rgb, mask_bw, prompt, target_size)
+     return np.array(edited_image)
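Outside the diff, the mask convention that `_binary_mask_from_rgba` implements can be checked with a small pure-NumPy stand-in (a hypothetical sketch; `rgb.mean(axis=2)` replaces OpenCV's weighted `COLOR_RGB2GRAY`, so values near the 128 threshold can differ slightly):

```python
import numpy as np

def binary_mask_from_rgba(mask: np.ndarray, invert_mask: bool = True) -> np.ndarray:
    """0/255 binary mask: transparent alpha, or bright RGB when alpha is opaque, means 'remove'."""
    if mask.shape[2] == 3:                      # no alpha channel at all
        alpha = np.full(mask.shape[:2], 255, dtype=np.uint8)
        rgb = mask
    else:
        alpha, rgb = mask[:, :, 3], mask[:, :, :3]
    if alpha.mean() < 240:                      # alpha carries information
        out = np.where(alpha < 128, 255, 0).astype(np.uint8)
    else:                                       # fall back to RGB brightness
        gray = rgb.mean(axis=2)                 # stand-in for cv2's weighted grayscale
        out = np.where(gray > 128, 255, 0).astype(np.uint8)
    if not invert_mask:
        out = 255 - out
    return out

# one transparent pixel in an otherwise opaque RGBA mask -> marked for removal
m = np.zeros((2, 2, 4), dtype=np.uint8)
m[..., 3] = 255
m[0, 0, 3] = 0
print(int(binary_mask_from_rgba(m)[0, 0]))  # -> 255
```

An all-opaque mask (alpha mean ≥ 240) falls through to the RGB branch, which is why a plain white-on-black mask image still works for marking removal regions.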
src/helper.py DELETED
@@ -1,87 +0,0 @@
- import os
- import sys
-
- from urllib.parse import urlparse
- import cv2
- import numpy as np
- import torch
- from torch.hub import download_url_to_file, get_dir
-
- LAMA_MODEL_URL = os.environ.get(
-     "LAMA_MODEL_URL",
-     "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
- )
-
-
- def download_model(url=LAMA_MODEL_URL):
-     parts = urlparse(url)
-     hub_dir = get_dir()
-     model_dir = os.path.join(hub_dir, "checkpoints")
-     if not os.path.isdir(model_dir):
-         os.makedirs(os.path.join(model_dir, "hub", "checkpoints"))
-     filename = os.path.basename(parts.path)
-     cached_file = os.path.join(model_dir, filename)
-     if not os.path.exists(cached_file):
-         sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
-         hash_prefix = None
-         download_url_to_file(url, cached_file, hash_prefix, progress=True)
-     return cached_file
-
-
- def ceil_modulo(x, mod):
-     if x % mod == 0:
-         return x
-     return (x // mod + 1) * mod
-
-
- def numpy_to_bytes(image_numpy: np.ndarray) -> bytes:
-     data = cv2.imencode(".jpg", image_numpy)[1]
-     image_bytes = data.tobytes()
-     return image_bytes
-
-
- def load_img(img_bytes, gray: bool = False):
-     nparr = np.frombuffer(img_bytes, np.uint8)
-     if gray:
-         np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
-     else:
-         np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
-         if len(np_img.shape) == 3 and np_img.shape[2] == 4:
-             np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
-         else:
-             np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
-
-     return np_img
-
-
- def norm_img(np_img):
-     if len(np_img.shape) == 2:
-         np_img = np_img[:, :, np.newaxis]
-     np_img = np.transpose(np_img, (2, 0, 1))
-     np_img = np_img.astype("float32") / 255
-     return np_img
-
-
- def resize_max_size(
-     np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
- ) -> np.ndarray:
-     # Resize image's longer size to size_limit if longer size larger than size_limit
-     h, w = np_img.shape[:2]
-     if max(h, w) > size_limit:
-         ratio = size_limit / max(h, w)
-         new_w = int(w * ratio + 0.5)
-         new_h = int(h * ratio + 0.5)
-         return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
-     else:
-         return np_img
-
-
- def pad_img_to_modulo(img, mod):
-     channels, height, width = img.shape
-     out_height = ceil_modulo(height, mod)
-     out_width = ceil_modulo(width, mod)
-     return np.pad(
-         img,
-         ((0, 0), (0, out_height - height), (0, out_width - width)),
-         mode="symmetric",
-     )
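The last two helpers existed because the old `run()` in core.py padded both image and mask with `mod=8` before LaMa inference. Reproduced standalone (same logic as the deleted code, with NumPy only):

```python
import numpy as np

def ceil_modulo(x: int, mod: int) -> int:
    # round x up to the next multiple of mod
    return x if x % mod == 0 else (x // mod + 1) * mod

def pad_img_to_modulo(img: np.ndarray, mod: int) -> np.ndarray:
    # img is [C, H, W]; mirror-pad bottom and right edges up to the next multiple of mod
    channels, height, width = img.shape
    out_h, out_w = ceil_modulo(height, mod), ceil_modulo(width, mod)
    return np.pad(img, ((0, 0), (0, out_h - height), (0, out_w - width)), mode="symmetric")

img = np.zeros((3, 30, 50), dtype=np.float32)
print(pad_img_to_modulo(img, 8).shape)  # -> (3, 32, 56)
```

`run()` then cropped the inpainted result back to `origin_height` × `origin_width`, undoing this padding after inference.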
src/st_style.py DELETED
@@ -1,42 +0,0 @@
- button_style = """
- <style>
- div.stButton > button:first-child {
-     background-color: rgb(255, 75, 75);
-     color: rgb(255, 255, 255);
- }
- div.stButton > button:hover {
-     background-color: rgb(255, 75, 75);
-     color: rgb(255, 255, 255);
- }
- div.stButton > button:active {
-     background-color: rgb(255, 75, 75);
-     color: rgb(255, 255, 255);
- }
- div.stButton > button:focus {
-     background-color: rgb(255, 75, 75);
-     color: rgb(255, 255, 255);
- }
- .css-1cpxqw2:focus:not(:active) {
-     background-color: rgb(255, 75, 75);
-     border-color: rgb(255, 75, 75);
-     color: rgb(255, 255, 255);
- }
- """
-
- style = """
- <style>
- #MainMenu {
-     visibility: hidden;
- }
- footer {
-     visibility: hidden;
- }
- header {
-     visibility: hidden;
- }
- </style>
- """
-
-
- def apply_prod_style(st):
-     return st.markdown(style, unsafe_allow_html=True)