basilbenny1002 committed on
Commit
dd7bc5d
·
verified ·
1 Parent(s): f0dcd66

Upload 9 files

Browse files
app/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # App package
app/api/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # API package
app/api/routes.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Optional
3
+
4
+ from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile
5
+ from fastapi.responses import FileResponse, JSONResponse
6
+
7
+ from app.core.config import settings
8
+ from app.services.image_selector import ImageSelectorService
9
+
10
+
11
# Router on which every endpoint in this module is registered.
router = APIRouter()
# Single service instance, built once at import time and shared by all handlers below.
selector_service = ImageSelectorService()
13
+
14
+
15
@router.post("/upload")
async def upload_file(file: UploadFile = File(...), user_id: str = Form(...)):
    """Store one uploaded file inside the caller's media folder.

    ``user_id`` and the client-supplied filename are untrusted form data, so
    each must be a single plain path component before it is joined onto
    ``settings.MEDIA_ROOT``; otherwise a value such as ``../x`` would write
    outside the media root.

    Args:
        file: The multipart upload to persist.
        user_id: Name of the per-user folder under ``settings.MEDIA_ROOT``.

    Returns:
        A dict with a ``message`` naming the saved file.

    Raises:
        HTTPException: 400 when ``user_id`` or the filename is empty or is
            not a single path component.
    """
    try:
        # Reject "", ".", ".." and anything containing a path separator.
        if user_id in {"", ".", ".."} or Path(user_id).name != user_id:
            raise HTTPException(status_code=400, detail="Invalid user_id")

        # Keep only the last component of the client filename; backslashes are
        # normalised first so Windows-style paths are reduced as well.
        safe_name = Path((file.filename or "").replace("\\", "/")).name
        if safe_name in {"", ".", ".."}:
            raise HTTPException(status_code=400, detail="Invalid filename")

        user_dir = settings.MEDIA_ROOT / user_id
        user_dir.mkdir(parents=True, exist_ok=True)
        dest = user_dir / safe_name
        with dest.open("wb") as out:
            # Stream in 1 MiB chunks so the upload is never held in memory whole.
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                out.write(chunk)
    finally:
        # Always release the upload, including when validation rejects it.
        await file.close()
    return {"message": f"{safe_name} saved"}
30
+
31
+
32
@router.post("/process")
async def process_images(
    background_tasks: BackgroundTasks,
    user_id: str = Form(...),
    similarity: float = Form(0.87),
    use_aesthetics: bool = Form(True),
):
    """Schedule the image-selection run for one user as a background task.

    The request returns immediately; the work itself is done by
    ``selector_service.choose_best`` after the response has been sent.

    Args:
        background_tasks: Task queue the selection run is added to.
        user_id: Name of the per-user folder under ``settings.MEDIA_ROOT``.
        similarity: Threshold forwarded to ``choose_best``.
        use_aesthetics: Forwarded to ``choose_best``.

    Returns:
        ``{"status": "started"}`` once the task has been queued.

    Raises:
        HTTPException: 400 when ``user_id`` is not a single path component,
            404 when the user's folder does not exist.
    """
    # user_id is untrusted form data: refuse anything that could escape MEDIA_ROOT.
    if user_id in {"", ".", ".."} or Path(user_id).name != user_id:
        raise HTTPException(status_code=400, detail="Invalid user_id")

    input_dir = settings.MEDIA_ROOT / user_id
    # is_dir() rather than exists(): a plain file here cannot be listed later.
    if not input_dir.is_dir():
        raise HTTPException(status_code=404, detail="User directory not found")

    # add_task forwards keyword arguments, so no wrapper closure is needed.
    background_tasks.add_task(
        selector_service.choose_best,
        user_id=user_id,
        input_dir=input_dir,
        output_dir=input_dir / "output",
        similarity=similarity,
        use_aesthetics=use_aesthetics,
    )
    return {"status": "started"}
49
+
50
+
51
@router.get("/download/{user_id}")
async def download_zip(user_id: str, background_tasks: BackgroundTasks):
    """Zip the user's ``output`` folder, send it, then delete the user's data.

    After the response has been sent, a background task removes both the
    archive and the user's whole media folder.

    Args:
        user_id: Name of the per-user folder under ``settings.MEDIA_ROOT``.
        background_tasks: Task queue the cleanup is added to.

    Returns:
        A ``FileResponse`` serving ``<user_id>_output.zip``.

    Raises:
        HTTPException: 400 when ``user_id`` is not a single path component,
            404 when the user has no output folder.
    """
    from shutil import make_archive, rmtree

    # user_id ends up in an rmtree() target, so it must not be able to point
    # outside (or at) MEDIA_ROOT itself.
    if user_id in {"", ".", ".."} or Path(user_id).name != user_id:
        raise HTTPException(status_code=400, detail="Invalid user_id")

    user_root = settings.MEDIA_ROOT / user_id
    output_dir = user_root / "output"
    if not output_dir.is_dir():
        raise HTTPException(status_code=404, detail="No output for this user")

    # make_archive appends ".zip" itself and returns the final path; using that
    # return value avoids deriving the base name by string replacement, which
    # would also strip any other ".zip" occurring in the path.
    archive_stem = f"{user_id}_output"
    zip_path = Path(make_archive(str(settings.MEDIA_ROOT / archive_stem), "zip", output_dir))

    # After sending the file, delete the zip and the user's media folder
    def _cleanup():
        try:
            zip_path.unlink()
        except OSError:
            # Best-effort cleanup: a missing or locked archive must not fail the task.
            pass
        rmtree(user_root, ignore_errors=True)

    background_tasks.add_task(_cleanup)
    return FileResponse(zip_path, filename=f"{archive_stem}.zip")
76
+
77
+
78
@router.get("/progress/{user_id}")
async def get_progress(user_id: str):
    """Return the progress snapshot the selector service holds for ``user_id``."""
    snapshot = selector_service.get_progress(user_id)
    return snapshot
81
+
82
+
83
@router.get("/")
async def health():
    """Health check: always answers HTTP 200 with a fixed JSON body."""
    payload = {"status": "ok"}
    return JSONResponse(content=payload, status_code=200)
app/core/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Core package
app/core/config.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import os
3
+
4
+
5
class Settings:
    """Filesystem locations used by the application.

    The class body runs once at import time: it picks the media root and
    creates that directory on disk.
    """

    # Three directory levels above this file.
    BASE_DIR = Path(__file__).resolve().parent.parent.parent

    # Prefer /data when it exists (persistent storage on an upgraded
    # HuggingFace Space); otherwise fall back to the ephemeral /tmp.
    MEDIA_ROOT = Path("/data" if os.path.exists("/data") else "/tmp").joinpath("media", "temp")
    MEDIA_ROOT.mkdir(parents=True, exist_ok=True)

    # Database: per-user embeddings under their media folder
    def user_db_path(self, user_id: str) -> Path:
        """Return the embeddings DB path for ``user_id``, creating its folder.

        The file itself is not created here, only the directory that holds it.
        """
        user_folder = self.MEDIA_ROOT / user_id
        user_folder.mkdir(parents=True, exist_ok=True)
        return user_folder / "embeddings.db"


settings = Settings()
app/repositories/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Repositories package
app/repositories/embeddings.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from pathlib import Path
3
+ from typing import Iterable, Tuple, List
4
+
5
+
6
class EmbeddingsRepository:
    """SQLite-backed store mapping an image path to its embedding bytes.

    The table ``images(path TEXT PRIMARY KEY, embedding BLOB)`` is created on
    construction if it does not exist yet. Every mutating method commits
    immediately. The repository can be used as a context manager, and
    ``close()`` may be called more than once.
    """

    def __init__(self, db_path: Path):
        """Open (or create) the database at ``db_path`` and ensure the schema.

        Raises:
            sqlite3.Error: If the database cannot be opened or initialised;
                the connection is closed before the error propagates.
        """
        self.db_path = db_path
        # Background tasks may use a different thread; allow cross-thread access
        self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
        self._closed = False
        try:
            self._cursor = self._conn.cursor()
            self._cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS images (
                    path TEXT PRIMARY KEY,
                    embedding BLOB
                )
                """
            )
            self._conn.commit()
        except sqlite3.Error:
            # Do not leak the connection when schema creation fails.
            self._closed = True
            self._conn.close()
            raise

    def __enter__(self) -> "EmbeddingsRepository":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def upsert(self, path: str, embedding: bytes) -> None:
        """Insert the row for ``path``, replacing any existing one."""
        self._cursor.execute(
            "INSERT OR REPLACE INTO images(path, embedding) VALUES (?, ?)",
            (path, embedding),
        )
        self._conn.commit()

    def delete_many(self, paths: Iterable[str]) -> None:
        """Delete the rows for all ``paths``; unknown paths are ignored."""
        self._cursor.executemany(
            "DELETE FROM images WHERE path=?", ((p,) for p in paths)
        )
        self._conn.commit()

    def list_all(self) -> List[Tuple[str, bytes]]:
        """Return every stored ``(path, embedding)`` row."""
        self._cursor.execute("SELECT path, embedding FROM images")
        return self._cursor.fetchall()

    def close(self) -> None:
        """Close the cursor and the connection; later calls do nothing."""
        if self._closed:
            return
        self._closed = True
        try:
            self._cursor.close()
        finally:
            self._conn.close()
app/services/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Services package
app/services/image_selector.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import List, Optional, Dict, Any
7
+
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+ from torchvision import models, transforms
12
+
13
+ from app.repositories.embeddings import EmbeddingsRepository
14
+ from app.core.config import settings
15
+
16
+
17
@dataclass
class SelectionResult:
    """Outcome of one selection run."""

    # Paths of the images that were chosen as the best of their group.
    kept: List[str]
    # Paths of the other group members that were deleted.
    removed: List[str]
21
+
22
+
23
class ImageSelectorService:
    """Group similar images and keep the best-scoring one of each group.

    Images are embedded with a ResNet50 whose final layer is removed. The
    embeddings are L2-normalised, so the dot product of two embeddings is
    their cosine similarity. Progress of a run is tracked per user in memory
    and can be read with ``get_progress``.
    """

    def __init__(self, db_repo: Optional[EmbeddingsRepository] = None, device: Optional[str] = None):
        """Load the feature extractor onto ``device`` (CUDA if available, else CPU).

        ``db_repo`` is accepted for backward compatibility but ignored: the
        repository is user-scoped and created inside ``choose_best``.
        """
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        # Load feature extractor (ResNet50 without classifier)
        base = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.feature_extractor = torch.nn.Sequential(*list(base.children())[:-1]).to(self.device)
        self.feature_extractor.eval()

        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

        # Note: repository is now user-scoped and created within choose_best.
        # db_repo parameter is ignored to avoid accidental cross-user sharing.

        # Lazy-load aesthetics predictor to avoid dependency unless needed
        self._predictor = None
        self._processor = None
        # Progress tracking per user
        # { user_id: { stage: int, percentage: int, eta_seconds: Optional[int], status: str, ... } }
        self._progress: Dict[str, Dict[str, Any]] = {}

    def _ensure_aesthetics(self):
        """Load the aesthetics predictor and its processor on first use."""
        if self._predictor is None or self._processor is None:
            from transformers import CLIPProcessor
            from aesthetics_predictor import AestheticsPredictorV1

            model_id = "shunk031/aesthetics-predictor-v1-vit-large-patch14"
            self._predictor = AestheticsPredictorV1.from_pretrained(model_id).to(self.device)
            self._processor = CLIPProcessor.from_pretrained(model_id)

    def embed_image(self, image_path: Path) -> np.ndarray:
        """Return the L2-normalised float32 embedding of the image at ``image_path``."""
        # The context manager closes the file handle; convert() returns an
        # independent in-memory image, so it stays usable afterwards.
        with Image.open(image_path) as raw:
            img = raw.convert("RGB")
        tensor = self.transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            emb = self.feature_extractor(tensor).squeeze().detach().cpu().numpy().astype(np.float32)
        norm = np.linalg.norm(emb)
        # Guard against division by zero for an all-zero embedding.
        return emb / max(norm, 1e-8)

    def add_image(self, image_path: Path, repo: EmbeddingsRepository) -> None:
        """Embed ``image_path`` and store the embedding in ``repo`` under its path."""
        emb = self.embed_image(image_path)
        repo.upsert(str(image_path), emb.tobytes())

    def predict_aesthetic(self, image_path: Path) -> float:
        """Return the aesthetics predictor's score for the image at ``image_path``."""
        self._ensure_aesthetics()
        with Image.open(image_path) as raw:
            img = raw.convert("RGB")
        inputs = self._processor(images=img, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self._predictor(**inputs)
        return float(outputs.logits[0].item())

    def find_similar(self, query_image: Path, threshold: float, repo: EmbeddingsRepository) -> List[str]:
        """Return stored paths whose similarity to ``query_image`` is at least ``threshold``."""
        q = self.embed_image(query_image)
        return [
            path
            for path, emb_blob in repo.list_all()
            if float(np.dot(q, np.frombuffer(emb_blob, dtype=np.float32))) >= threshold
        ]

    @staticmethod
    def _stage_progress(stage: int, processed: int, total: int, started: float, now: float) -> Dict[str, Any]:
        """Build the progress fields for ``stage`` after ``processed`` of ``total`` items."""
        elapsed = max(now - started, 1e-6)
        rate = processed / elapsed
        remaining = max(total - processed, 0)
        return {
            "stage": stage,
            "percentage": int((processed / max(total, 1)) * 100),
            "eta_seconds": int(remaining / rate) if rate > 0 else None,
            f"processed_stage{stage}": processed,
        }

    def _resolve_group(self, similar: List[str], group_dir: Path, output_dir: Path, repo: EmbeddingsRepository, use_aesthetics: bool, log) -> tuple[Optional[str], List[str]]:
        """Keep the best image of one group; return ``(kept_path_or_None, removed_paths)``."""
        import shutil

        # Remove found images from DB immediately to avoid regrouping in later iterations
        try:
            repo.delete_many(similar)
        except Exception:
            log.warning("Could not remove group from the embeddings DB", exc_info=True)

        group_dir.mkdir(exist_ok=True)
        best_score = -1e9
        best_path: Optional[str] = None
        for path in similar:
            src = Path(path)
            score = 0.0
            if use_aesthetics:
                try:
                    score = self.predict_aesthetic(src)
                except Exception:
                    # Best-effort: an unscorable image competes with a neutral score.
                    log.warning("Could not score %s", src, exc_info=True)
            # Strict comparison: on ties the first image encountered wins.
            if score > best_score:
                best_score = score
                best_path = path
            # copy to group folder for inspection
            try:
                dest = group_dir / src.name
                if not dest.exists():
                    shutil.copyfile(src, dest)
            except OSError:
                log.warning("Could not copy %s into %s", src, group_dir, exc_info=True)

        if best_path is None:
            return None, []

        best = Path(best_path)
        try:
            shutil.copyfile(best, output_dir / best.name)
        except OSError:
            # Without a saved winner, deleting the rest would lose the whole
            # group from the output, so leave every file where it is.
            log.warning("Could not copy %s to %s; group left untouched", best, output_dir, exc_info=True)
            return None, []

        try:
            best.unlink()
        except OSError:
            log.warning("Could not delete %s after copying it", best, exc_info=True)

        removed: List[str] = []
        for path in similar:
            if path == best_path:
                continue
            try:
                Path(path).unlink()
            except OSError:
                log.warning("Could not delete %s", path, exc_info=True)
            else:
                removed.append(path)
        return best_path, removed

    def choose_best(self, user_id: str, input_dir: Path, output_dir: Path, similarity: float = 0.87, use_aesthetics: bool = True) -> SelectionResult:
        """Index every file in ``input_dir``, group similar ones, keep the best of each.

        Stage 1 stores an embedding for every file directly inside
        ``input_dir``. Stage 2 walks the same file list; for each file still
        on disk it collects the stored images at least ``similarity`` similar
        to it, copies them into a numbered folder inside ``input_dir``, copies
        the best-scoring one to ``output_dir`` and deletes the group's files
        from ``input_dir``. With ``use_aesthetics`` false every image scores
        0.0, so the first image of each group is kept.

        Progress is published under ``user_id``; on an unexpected error the
        status becomes ``"failed"`` and the error is re-raised. The embeddings
        DB is always closed before this method returns or raises.

        Returns:
            A ``SelectionResult`` listing the kept and the removed paths.
        """
        import logging
        import time as _time

        log = logging.getLogger(__name__)

        input_dir = Path(input_dir)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Prepare file list once
        files = [fp for fp in input_dir.iterdir() if fp.is_file()]
        total = len(files)

        self._progress[user_id] = {
            "stage": 1,
            "percentage": 0,
            "eta_seconds": None,
            "status": "indexing",
            "total_stage1": total,
            "processed_stage1": 0,
            "total_stage2": total,
            "processed_stage2": 0,
        }

        kept: List[str] = []
        removed: List[str] = []
        repo: Optional[EmbeddingsRepository] = None
        try:
            # Create user-scoped embeddings repository
            repo = EmbeddingsRepository(settings.user_db_path(user_id))

            # Stage 1: embeddings indexing
            stage1_start = _time.time()
            for done, fp in enumerate(files, start=1):
                try:
                    self.add_image(fp, repo)
                except Exception:
                    # Best-effort: files that cannot be embedded are skipped.
                    log.warning("Could not index %s", fp, exc_info=True)
                self._progress[user_id].update(
                    self._stage_progress(1, done, total, stage1_start, _time.time())
                )

            # Stage 2: selection
            stage2_start = _time.time()
            self._progress[user_id].update({"stage": 2, "percentage": 0, "eta_seconds": None, "status": "selecting"})

            for index, fp in enumerate(files, start=1):
                similar: List[str] = []
                # A file consumed by an earlier group is no longer on disk and
                # cannot be embedded again, so it is skipped outright.
                if fp.is_file():
                    try:
                        similar = self.find_similar(fp, threshold=similarity, repo=repo)
                    except Exception:
                        log.warning("Could not query similar images for %s", fp, exc_info=True)
                # The group folder is named after the file's 1-based position
                # and is only created when the group is non-empty.
                if similar:
                    best, dropped = self._resolve_group(
                        similar, input_dir / str(index), output_dir, repo, use_aesthetics, log
                    )
                    if best is not None:
                        kept.append(best)
                    removed.extend(dropped)
                # Update progress for stage 2
                self._progress[user_id].update(
                    self._stage_progress(2, index, total, stage2_start, _time.time())
                )
        except Exception:
            self._progress[user_id].update({"status": "failed", "eta_seconds": None})
            raise
        finally:
            # Ensure DB is closed before returning so the file can be deleted on Windows
            if repo is not None:
                try:
                    repo.close()
                except Exception:
                    log.warning("Could not close the embeddings DB", exc_info=True)

        # Completed
        self._progress[user_id].update({"stage": 2, "percentage": 100, "eta_seconds": 0, "status": "completed"})
        return SelectionResult(kept=kept, removed=removed)

    def get_progress(self, user_id: str) -> Dict[str, Any]:
        """Return a copy of the progress for ``user_id``, or an idle placeholder."""
        return dict(
            self._progress.get(
                user_id,
                {"stage": 0, "percentage": 0, "eta_seconds": None, "status": "idle"},
            )
        )