Spaces:
Sleeping
Sleeping
Upload 12 files
Browse files- Dockerfile +26 -0
- README.md +45 -5
- app.py +9 -0
- auth.py +97 -0
- auth_storage.py +201 -0
- email_service.py +121 -0
- env.example +30 -0
- main.py +722 -0
- requirements.txt +14 -0
- run.py +10 -0
- storage_hf.py +146 -0
- style_transfer.py +310 -0
Dockerfile
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 4 |
+
PYTHONUNBUFFERED=1 \
|
| 5 |
+
PIP_NO_CACHE_DIR=1 \
|
| 6 |
+
PORT=7860
|
| 7 |
+
|
| 8 |
+
WORKDIR /app
|
| 9 |
+
|
| 10 |
+
# Install git first (needed for Hugging Face Spaces build process)
|
| 11 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 12 |
+
git \
|
| 13 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
+
|
| 15 |
+
# System deps (optional, pillow usually fine without extra)
|
| 16 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 17 |
+
build-essential \
|
| 18 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 19 |
+
|
| 20 |
+
COPY requirements.txt ./
|
| 21 |
+
RUN pip install --upgrade pip && pip install -r requirements.txt
|
| 22 |
+
|
| 23 |
+
COPY . .
|
| 24 |
+
|
| 25 |
+
CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${PORT:-7860} --proxy-headers --forwarded-allow-ips=*"]
|
| 26 |
+
|
README.md
CHANGED
|
@@ -1,10 +1,50 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Neural Style Transfer API
|
| 3 |
+
emoji: 🎨
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: pink
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
+
license: mit
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# Neural Style Transfer API
|
| 13 |
+
|
| 14 |
+
FastAPI backend for neural style transfer using PyTorch and VGG19.
|
| 15 |
+
|
| 16 |
+
## Features
|
| 17 |
+
|
| 18 |
+
- Style transfer processing with customizable parameters
|
| 19 |
+
- Gallery management for generated images
|
| 20 |
+
- User authentication and permission requests
|
| 21 |
+
- Image storage via Hugging Face datasets
|
| 22 |
+
|
| 23 |
+
## Environment Variables
|
| 24 |
+
|
| 25 |
+
Set these in the Hugging Face Space secrets:
|
| 26 |
+
|
| 27 |
+
- `HF_DATASET_REPO`: Your Hugging Face dataset repository (e.g., `username/style-transfer-data`)
|
| 28 |
+
- `HF_TOKEN`: Hugging Face access token with write permissions
|
| 29 |
+
- `MASTER_PASSWORD`: Master password for admin access
|
| 30 |
+
- `ADMIN_EMAIL`: Admin email for receiving permission request notifications
|
| 31 |
+
- `ALLOWED_ORIGINS`: Comma-separated list of allowed CORS origins
|
| 32 |
+
- `SMTP_HOST`: SMTP server host (optional, for email notifications)
|
| 33 |
+
- `SMTP_PORT`: SMTP server port (optional)
|
| 34 |
+
- `SMTP_USER`: SMTP username (optional)
|
| 35 |
+
- `SMTP_PASSWORD`: SMTP password (optional)
|
| 36 |
+
- `SMTP_FROM_EMAIL`: Email address to send from (optional)
|
| 37 |
+
|
| 38 |
+
## API Endpoints
|
| 39 |
+
|
| 40 |
+
- `GET /api/health` - Health check
|
| 41 |
+
- `POST /api/transfer` - Create style transfer job (requires auth)
|
| 42 |
+
- `GET /api/transfer/{job_id}` - Get job status
|
| 43 |
+
- `GET /api/gallery` - List gallery items
|
| 44 |
+
- `GET /api/gallery/{item_id}` - Get gallery item
|
| 45 |
+
- `DELETE /api/gallery/{item_id}` - Delete gallery item (requires auth)
|
| 46 |
+
- `POST /api/auth/login` - Login
|
| 47 |
+
- `POST /api/auth/requests` - Submit permission request
|
| 48 |
+
- `GET /api/auth/requests` - List requests (admin only)
|
| 49 |
+
- `POST /api/auth/requests/{id}/approve` - Approve request (admin only)
|
| 50 |
+
- `POST /api/auth/requests/{id}/reject` - Reject request (admin only)
|
app.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Entry point for Hugging Face Spaces deployment.
|
| 3 |
+
This file simply imports and exposes the FastAPI app from main.py.
|
| 4 |
+
"""
|
| 5 |
+
from main import app
|
| 6 |
+
|
| 7 |
+
# The app is already configured in main.py
|
| 8 |
+
# Hugging Face Spaces will automatically detect and serve this FastAPI app
|
| 9 |
+
#filler
|
auth.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from jose import jwt
|
| 3 |
+
from jose.exceptions import ExpiredSignatureError, JWTError
|
| 4 |
+
from datetime import datetime, timedelta
|
| 5 |
+
from typing import Optional, Dict, Any
|
| 6 |
+
from fastapi import HTTPException, status, Depends
|
| 7 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 8 |
+
from passlib.context import CryptContext
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
# JWT configuration
|
| 14 |
+
SECRET_KEY = os.getenv("JWT_SECRET_KEY", os.getenv("MASTER_PASSWORD", "change-me-in-production"))
|
| 15 |
+
ALGORITHM = "HS256"
|
| 16 |
+
ACCESS_TOKEN_EXPIRE_HOURS = 24 * 7 # 7 days
|
| 17 |
+
|
| 18 |
+
# Password hashing
|
| 19 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
| 20 |
+
|
| 21 |
+
# HTTP Bearer token scheme
|
| 22 |
+
security = HTTPBearer()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
| 26 |
+
"""Verify a password against a hash."""
|
| 27 |
+
return pwd_context.verify(plain_password, hashed_password)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_password_hash(password: str) -> str:
|
| 31 |
+
"""Hash a password."""
|
| 32 |
+
return pwd_context.hash(password)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
| 36 |
+
"""Create a JWT access token."""
|
| 37 |
+
to_encode = data.copy()
|
| 38 |
+
if expires_delta:
|
| 39 |
+
expire = datetime.utcnow() + expires_delta
|
| 40 |
+
else:
|
| 41 |
+
expire = datetime.utcnow() + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
|
| 42 |
+
to_encode.update({"exp": expire})
|
| 43 |
+
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
| 44 |
+
return encoded_jwt
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def verify_token(token: str) -> Optional[Dict[str, Any]]:
|
| 48 |
+
"""Verify and decode a JWT token."""
|
| 49 |
+
try:
|
| 50 |
+
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
| 51 |
+
return payload
|
| 52 |
+
except ExpiredSignatureError:
|
| 53 |
+
logger.warning("Token has expired")
|
| 54 |
+
return None
|
| 55 |
+
except JWTError:
|
| 56 |
+
logger.warning("Invalid token")
|
| 57 |
+
return None
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
async def get_current_user(
|
| 61 |
+
credentials: HTTPAuthorizationCredentials = Depends(security)
|
| 62 |
+
) -> Dict[str, Any]:
|
| 63 |
+
"""
|
| 64 |
+
Dependency to get the current authenticated user from JWT token.
|
| 65 |
+
"""
|
| 66 |
+
token = credentials.credentials
|
| 67 |
+
payload = verify_token(token)
|
| 68 |
+
if payload is None:
|
| 69 |
+
raise HTTPException(
|
| 70 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 71 |
+
detail="Invalid authentication credentials",
|
| 72 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 73 |
+
)
|
| 74 |
+
return payload
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
async def get_current_user_optional(
|
| 78 |
+
credentials: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False))
|
| 79 |
+
) -> Optional[Dict[str, Any]]:
|
| 80 |
+
"""
|
| 81 |
+
Dependency to get the current user if authenticated, None otherwise.
|
| 82 |
+
"""
|
| 83 |
+
if credentials is None:
|
| 84 |
+
return None
|
| 85 |
+
token = credentials.credentials
|
| 86 |
+
payload = verify_token(token)
|
| 87 |
+
return payload
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def require_auth(func):
|
| 91 |
+
"""
|
| 92 |
+
Decorator to require authentication for an endpoint.
|
| 93 |
+
"""
|
| 94 |
+
async def wrapper(*args, **kwargs):
|
| 95 |
+
# This will be handled by the Depends(get_current_user) in the route
|
| 96 |
+
return await func(*args, **kwargs)
|
| 97 |
+
return wrapper
|
auth_storage.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Any, Dict, List, Optional
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
import uuid
|
| 7 |
+
|
| 8 |
+
from huggingface_hub import HfApi, CommitOperationAdd, create_commit, hf_hub_url
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
USERS_FILE_PATH = "auth/users.json"
|
| 13 |
+
REQUESTS_FILE_PATH = "auth/requests.json"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def build_dataset_resolve_url(repo_id: str, path_in_repo: str, revision: str = "main") -> str:
|
| 17 |
+
"""
|
| 18 |
+
Build a CDN-resolved URL for a file stored in a Hugging Face dataset repo.
|
| 19 |
+
"""
|
| 20 |
+
return hf_hub_url(repo_id=repo_id, filename=path_in_repo, repo_type="dataset", revision=revision)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class AuthStorageClient:
|
| 24 |
+
"""
|
| 25 |
+
Helper for managing user authentication data and permission requests
|
| 26 |
+
in a Hugging Face dataset repository.
|
| 27 |
+
|
| 28 |
+
Repo format:
|
| 29 |
+
- auth/users.json
|
| 30 |
+
- auth/requests.json
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, dataset_repo: str, hf_token: Optional[str] = None, revision: str = "main"):
|
| 34 |
+
if not dataset_repo:
|
| 35 |
+
raise ValueError("HF_DATASET_REPO is not set. Please configure the dataset repository id.")
|
| 36 |
+
self.dataset_repo = dataset_repo
|
| 37 |
+
self.revision = revision
|
| 38 |
+
self.api = HfApi(token=hf_token) if hf_token else HfApi()
|
| 39 |
+
|
| 40 |
+
def load_users(self) -> List[Dict[str, Any]]:
|
| 41 |
+
"""
|
| 42 |
+
Download and parse users.json from the dataset. If missing, return [].
|
| 43 |
+
"""
|
| 44 |
+
try:
|
| 45 |
+
url = build_dataset_resolve_url(self.dataset_repo, USERS_FILE_PATH, self.revision)
|
| 46 |
+
import requests
|
| 47 |
+
headers = {}
|
| 48 |
+
if self.api.token:
|
| 49 |
+
headers["Authorization"] = f"Bearer {self.api.token}"
|
| 50 |
+
resp = requests.get(url, timeout=10, headers=headers)
|
| 51 |
+
if resp.status_code == 200:
|
| 52 |
+
data = resp.json()
|
| 53 |
+
return data.get("users", [])
|
| 54 |
+
logger.info("Users file not found at %s (status %s). Initializing empty users list.", url, resp.status_code)
|
| 55 |
+
return []
|
| 56 |
+
except Exception as e:
|
| 57 |
+
logger.error("Failed to load users from HF: %s", str(e))
|
| 58 |
+
return []
|
| 59 |
+
|
| 60 |
+
def save_users(self, users: List[Dict[str, Any]]) -> None:
|
| 61 |
+
"""
|
| 62 |
+
Commit a new version of users.json to the dataset repo.
|
| 63 |
+
"""
|
| 64 |
+
try:
|
| 65 |
+
payload = json.dumps({"users": users}, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
|
| 66 |
+
operations = [
|
| 67 |
+
CommitOperationAdd(path_in_repo=USERS_FILE_PATH, path_or_fileobj=payload)
|
| 68 |
+
]
|
| 69 |
+
create_commit(
|
| 70 |
+
repo_id=self.dataset_repo,
|
| 71 |
+
repo_type="dataset",
|
| 72 |
+
operations=operations,
|
| 73 |
+
commit_message="Update users.json",
|
| 74 |
+
revision=self.revision,
|
| 75 |
+
token=self.api.token,
|
| 76 |
+
)
|
| 77 |
+
except Exception as e:
|
| 78 |
+
logger.error("Failed to save users to HF: %s", str(e))
|
| 79 |
+
raise
|
| 80 |
+
|
| 81 |
+
def load_requests(self) -> List[Dict[str, Any]]:
|
| 82 |
+
"""
|
| 83 |
+
Download and parse requests.json from the dataset. If missing, return [].
|
| 84 |
+
"""
|
| 85 |
+
try:
|
| 86 |
+
url = build_dataset_resolve_url(self.dataset_repo, REQUESTS_FILE_PATH, self.revision)
|
| 87 |
+
import requests
|
| 88 |
+
headers = {}
|
| 89 |
+
if self.api.token:
|
| 90 |
+
headers["Authorization"] = f"Bearer {self.api.token}"
|
| 91 |
+
resp = requests.get(url, timeout=10, headers=headers)
|
| 92 |
+
if resp.status_code == 200:
|
| 93 |
+
data = resp.json()
|
| 94 |
+
return data.get("requests", [])
|
| 95 |
+
logger.info("Requests file not found at %s (status %s). Initializing empty requests list.", url, resp.status_code)
|
| 96 |
+
return []
|
| 97 |
+
except Exception as e:
|
| 98 |
+
logger.error("Failed to load requests from HF: %s", str(e))
|
| 99 |
+
return []
|
| 100 |
+
|
| 101 |
+
def save_requests(self, requests: List[Dict[str, Any]]) -> None:
|
| 102 |
+
"""
|
| 103 |
+
Commit a new version of requests.json to the dataset repo.
|
| 104 |
+
"""
|
| 105 |
+
try:
|
| 106 |
+
payload = json.dumps({"requests": requests}, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
|
| 107 |
+
operations = [
|
| 108 |
+
CommitOperationAdd(path_in_repo=REQUESTS_FILE_PATH, path_or_fileobj=payload)
|
| 109 |
+
]
|
| 110 |
+
create_commit(
|
| 111 |
+
repo_id=self.dataset_repo,
|
| 112 |
+
repo_type="dataset",
|
| 113 |
+
operations=operations,
|
| 114 |
+
commit_message="Update requests.json",
|
| 115 |
+
revision=self.revision,
|
| 116 |
+
token=self.api.token,
|
| 117 |
+
)
|
| 118 |
+
except Exception as e:
|
| 119 |
+
logger.error("Failed to save requests to HF: %s", str(e))
|
| 120 |
+
raise
|
| 121 |
+
|
| 122 |
+
def add_user(self, email: str, password_hash: str) -> None:
|
| 123 |
+
"""
|
| 124 |
+
Add a new user to the users list.
|
| 125 |
+
"""
|
| 126 |
+
users = self.load_users()
|
| 127 |
+
# Check if user already exists
|
| 128 |
+
if any(user.get("email") == email for user in users):
|
| 129 |
+
raise ValueError(f"User with email {email} already exists")
|
| 130 |
+
|
| 131 |
+
users.append({
|
| 132 |
+
"email": email,
|
| 133 |
+
"password_hash": password_hash
|
| 134 |
+
})
|
| 135 |
+
self.save_users(users)
|
| 136 |
+
|
| 137 |
+
def get_user(self, email: str) -> Optional[Dict[str, Any]]:
|
| 138 |
+
"""
|
| 139 |
+
Get a user by email.
|
| 140 |
+
"""
|
| 141 |
+
users = self.load_users()
|
| 142 |
+
return next((user for user in users if user.get("email") == email), None)
|
| 143 |
+
|
| 144 |
+
def delete_user(self, email: str) -> None:
|
| 145 |
+
"""
|
| 146 |
+
Delete a user by email.
|
| 147 |
+
"""
|
| 148 |
+
users = self.load_users()
|
| 149 |
+
users = [user for user in users if user.get("email") != email]
|
| 150 |
+
self.save_users(users)
|
| 151 |
+
|
| 152 |
+
def add_request(self, name: str, email: str, reason: str) -> str:
|
| 153 |
+
"""
|
| 154 |
+
Add a new permission request. Returns the request ID.
|
| 155 |
+
"""
|
| 156 |
+
requests = self.load_requests()
|
| 157 |
+
request_id = str(uuid.uuid4())
|
| 158 |
+
|
| 159 |
+
new_request = {
|
| 160 |
+
"id": request_id,
|
| 161 |
+
"name": name,
|
| 162 |
+
"email": email,
|
| 163 |
+
"reason": reason,
|
| 164 |
+
"timestamp": datetime.utcnow().isoformat(),
|
| 165 |
+
"status": "pending",
|
| 166 |
+
"reviewed_at": None,
|
| 167 |
+
"rejection_reason": None
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
requests.append(new_request)
|
| 171 |
+
self.save_requests(requests)
|
| 172 |
+
return request_id
|
| 173 |
+
|
| 174 |
+
def get_request(self, request_id: str) -> Optional[Dict[str, Any]]:
|
| 175 |
+
"""
|
| 176 |
+
Get a request by ID.
|
| 177 |
+
"""
|
| 178 |
+
requests = self.load_requests()
|
| 179 |
+
return next((req for req in requests if req.get("id") == request_id), None)
|
| 180 |
+
|
| 181 |
+
def update_request_status(self, request_id: str, status: str, rejection_reason: Optional[str] = None) -> None:
|
| 182 |
+
"""
|
| 183 |
+
Update the status of a request.
|
| 184 |
+
"""
|
| 185 |
+
requests = self.load_requests()
|
| 186 |
+
for req in requests:
|
| 187 |
+
if req.get("id") == request_id:
|
| 188 |
+
req["status"] = status
|
| 189 |
+
req["reviewed_at"] = datetime.utcnow().isoformat()
|
| 190 |
+
if rejection_reason:
|
| 191 |
+
req["rejection_reason"] = rejection_reason
|
| 192 |
+
break
|
| 193 |
+
self.save_requests(requests)
|
| 194 |
+
|
| 195 |
+
def delete_request(self, request_id: str) -> None:
|
| 196 |
+
"""
|
| 197 |
+
Delete a request by ID.
|
| 198 |
+
"""
|
| 199 |
+
requests = self.load_requests()
|
| 200 |
+
requests = [req for req in requests if req.get("id") != request_id]
|
| 201 |
+
self.save_requests(requests)
|
email_service.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import smtplib
|
| 3 |
+
import logging
|
| 4 |
+
from email.mime.text import MIMEText
|
| 5 |
+
from email.mime.multipart import MIMEMultipart
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class EmailService:
|
| 13 |
+
"""
|
| 14 |
+
Service for sending emails via SMTP.
|
| 15 |
+
Supports Gmail and other SMTP servers.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self):
|
| 19 |
+
self.smtp_host = os.getenv("SMTP_HOST", "smtp.gmail.com")
|
| 20 |
+
self.smtp_port = int(os.getenv("SMTP_PORT", "587"))
|
| 21 |
+
self.smtp_user = os.getenv("SMTP_USER", "")
|
| 22 |
+
self.smtp_password = os.getenv("SMTP_PASSWORD", "")
|
| 23 |
+
self.from_email = os.getenv("SMTP_FROM_EMAIL", self.smtp_user)
|
| 24 |
+
self.admin_email = os.getenv("ADMIN_EMAIL", "")
|
| 25 |
+
|
| 26 |
+
def _send_email(self, to_email: str, subject: str, body: str, is_html: bool = False) -> bool:
|
| 27 |
+
"""
|
| 28 |
+
Send an email via SMTP.
|
| 29 |
+
"""
|
| 30 |
+
if not self.smtp_user or not self.smtp_password:
|
| 31 |
+
logger.warning("SMTP credentials not configured. Email not sent.")
|
| 32 |
+
return False
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
msg = MIMEMultipart()
|
| 36 |
+
msg['From'] = self.from_email
|
| 37 |
+
msg['To'] = to_email
|
| 38 |
+
msg['Subject'] = subject
|
| 39 |
+
|
| 40 |
+
if is_html:
|
| 41 |
+
msg.attach(MIMEText(body, 'html'))
|
| 42 |
+
else:
|
| 43 |
+
msg.attach(MIMEText(body, 'plain'))
|
| 44 |
+
|
| 45 |
+
with smtplib.SMTP(self.smtp_host, self.smtp_port) as server:
|
| 46 |
+
server.starttls()
|
| 47 |
+
server.login(self.smtp_user, self.smtp_password)
|
| 48 |
+
server.send_message(msg)
|
| 49 |
+
|
| 50 |
+
logger.info(f"Email sent successfully to {to_email}")
|
| 51 |
+
return True
|
| 52 |
+
except Exception as e:
|
| 53 |
+
logger.error(f"Failed to send email to {to_email}: {str(e)}")
|
| 54 |
+
return False
|
| 55 |
+
|
| 56 |
+
def send_permission_request_notification(self, name: str, email: str, reason: str, timestamp: str) -> bool:
|
| 57 |
+
"""
|
| 58 |
+
Send email notification to admin when a permission request is submitted.
|
| 59 |
+
"""
|
| 60 |
+
if not self.admin_email:
|
| 61 |
+
logger.warning("ADMIN_EMAIL not configured. Notification not sent.")
|
| 62 |
+
return False
|
| 63 |
+
|
| 64 |
+
subject = f"New Permission Request: {name}"
|
| 65 |
+
body = f"""
|
| 66 |
+
A new permission request has been submitted:
|
| 67 |
+
|
| 68 |
+
Name: {name}
|
| 69 |
+
Email: {email}
|
| 70 |
+
Reason: {reason}
|
| 71 |
+
Timestamp: {timestamp}
|
| 72 |
+
|
| 73 |
+
Please review the request in the admin interface.
|
| 74 |
+
"""
|
| 75 |
+
return self._send_email(self.admin_email, subject, body)
|
| 76 |
+
|
| 77 |
+
def send_approval_email(self, user_email: str, user_name: str, password: str) -> bool:
|
| 78 |
+
"""
|
| 79 |
+
Send approval email to user with their account credentials.
|
| 80 |
+
"""
|
| 81 |
+
subject = "Your Style Transfer Account Has Been Approved"
|
| 82 |
+
body = f"""
|
| 83 |
+
Hello {user_name},
|
| 84 |
+
|
| 85 |
+
Your permission request has been approved! Your account has been created.
|
| 86 |
+
|
| 87 |
+
Login Credentials:
|
| 88 |
+
Email: {user_email}
|
| 89 |
+
Password: {password}
|
| 90 |
+
|
| 91 |
+
You can now access the Neural Style Transfer application and create style transfers.
|
| 92 |
+
|
| 93 |
+
Please keep your password secure and do not share it with anyone.
|
| 94 |
+
|
| 95 |
+
Best regards,
|
| 96 |
+
Style Transfer Team
|
| 97 |
+
"""
|
| 98 |
+
return self._send_email(user_email, subject, body)
|
| 99 |
+
|
| 100 |
+
def send_rejection_email(self, user_email: str, user_name: str, reason: Optional[str] = None) -> bool:
|
| 101 |
+
"""
|
| 102 |
+
Send rejection email to user.
|
| 103 |
+
"""
|
| 104 |
+
subject = "Permission Request Status Update"
|
| 105 |
+
body = f"""
|
| 106 |
+
Hello {user_name},
|
| 107 |
+
|
| 108 |
+
Thank you for your interest in the Neural Style Transfer application.
|
| 109 |
+
|
| 110 |
+
Unfortunately, your permission request has been declined at this time.
|
| 111 |
+
"""
|
| 112 |
+
if reason:
|
| 113 |
+
body += f"\nReason: {reason}\n"
|
| 114 |
+
|
| 115 |
+
body += """
|
| 116 |
+
If you have any questions, please feel free to reach out.
|
| 117 |
+
|
| 118 |
+
Best regards,
|
| 119 |
+
Style Transfer Team
|
| 120 |
+
"""
|
| 121 |
+
return self._send_email(user_email, subject, body)
|
env.example
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copy this file to `.env` in the same directory and fill in your values.
|
| 2 |
+
|
| 3 |
+
# Hugging Face Dataset repository to store uploads/results/gallery
|
| 4 |
+
# Format: <username>/<dataset-name>
|
| 5 |
+
HF_DATASET_REPO=your-username/style-transfer-data
|
| 6 |
+
|
| 7 |
+
# Hugging Face access token with write permission to the dataset
|
| 8 |
+
# Create at: https://huggingface.co/settings/tokens
|
| 9 |
+
HF_TOKEN=hf_your_secret_token
|
| 10 |
+
|
| 11 |
+
# Comma-separated list of allowed origins for CORS
|
| 12 |
+
# Include your local dev, GitHub Pages, and deployed frontend URLs
|
| 13 |
+
ALLOWED_ORIGINS=http://localhost:4200,https://your-username.github.io
|
| 14 |
+
|
| 15 |
+
# Master password for admin access (required)
|
| 16 |
+
MASTER_PASSWORD=your-master-password-here
|
| 17 |
+
|
| 18 |
+
# Admin email for receiving permission request notifications (required)
|
| 19 |
+
ADMIN_EMAIL=admin@example.com
|
| 20 |
+
|
| 21 |
+
# Email service configuration (for sending notifications)
|
| 22 |
+
# SMTP settings (for Gmail or other SMTP servers)
|
| 23 |
+
SMTP_HOST=smtp.gmail.com
|
| 24 |
+
SMTP_PORT=587
|
| 25 |
+
SMTP_USER=your-email@gmail.com
|
| 26 |
+
SMTP_PASSWORD=your-app-password
|
| 27 |
+
SMTP_FROM_EMAIL=your-email@gmail.com
|
| 28 |
+
|
| 29 |
+
# Optional: JWT secret key (defaults to MASTER_PASSWORD if not set)
|
| 30 |
+
# JWT_SECRET_KEY=your-jwt-secret-key
|
main.py
ADDED
|
@@ -0,0 +1,722 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uuid
|
| 3 |
+
import json
|
| 4 |
+
import secrets
|
| 5 |
+
from fastapi import FastAPI, UploadFile, File, Form, BackgroundTasks, Depends, HTTPException, status
|
| 6 |
+
from fastapi.responses import JSONResponse
|
| 7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
+
import logging
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
from typing import Dict, Optional, Any
|
| 11 |
+
import asyncio
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
import time
|
| 14 |
+
import math
|
| 15 |
+
import tempfile
|
| 16 |
+
|
| 17 |
+
from style_transfer import transfer_style
|
| 18 |
+
from storage_hf import (
|
| 19 |
+
HFStorageClient,
|
| 20 |
+
build_dataset_resolve_url
|
| 21 |
+
)
|
| 22 |
+
from auth import (
|
| 23 |
+
verify_password,
|
| 24 |
+
get_password_hash,
|
| 25 |
+
create_access_token,
|
| 26 |
+
verify_token,
|
| 27 |
+
get_current_user
|
| 28 |
+
)
|
| 29 |
+
from auth_storage import AuthStorageClient
|
| 30 |
+
from email_service import EmailService
|
| 31 |
+
|
| 32 |
+
# Set up logging
|
| 33 |
+
logging.basicConfig(level=logging.INFO)
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
# Custom JSON encoder to handle infinity
|
| 37 |
+
class CustomJSONEncoder(json.JSONEncoder):
|
| 38 |
+
def default(self, obj):
|
| 39 |
+
if isinstance(obj, float):
|
| 40 |
+
if math.isinf(obj):
|
| 41 |
+
return "Infinity" if obj > 0 else "-Infinity"
|
| 42 |
+
if math.isnan(obj):
|
| 43 |
+
return "NaN"
|
| 44 |
+
return super().default(obj)
|
| 45 |
+
|
| 46 |
+
# Custom JSONResponse that uses our custom encoder
|
| 47 |
+
class CustomJSONResponse(JSONResponse):
|
| 48 |
+
def render(self, content: Any) -> bytes:
|
| 49 |
+
return json.dumps(
|
| 50 |
+
content,
|
| 51 |
+
ensure_ascii=False,
|
| 52 |
+
allow_nan=False,
|
| 53 |
+
indent=None,
|
| 54 |
+
separators=(",", ":"),
|
| 55 |
+
cls=CustomJSONEncoder,
|
| 56 |
+
).encode("utf-8")
|
| 57 |
+
|
| 58 |
+
# Initialize FastAPI
|
| 59 |
+
app = FastAPI(title="Neural Style Transfer API", default_response_class=CustomJSONResponse)
|
| 60 |
+
|
| 61 |
+
# Add request logging middleware
|
| 62 |
+
@app.middleware("http")
|
| 63 |
+
async def log_requests(request, call_next):
|
| 64 |
+
logger.info(f"Incoming request: {request.method} {request.url.path} from {request.client.host if request.client else 'unknown'}")
|
| 65 |
+
response = await call_next(request)
|
| 66 |
+
logger.info(f"Response: {request.method} {request.url.path} -> {response.status_code}")
|
| 67 |
+
return response
|
| 68 |
+
|
| 69 |
+
# HF storage configuration
|
| 70 |
+
HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "") # e.g. username/style-transfer-data
|
| 71 |
+
HF_TOKEN = os.getenv("HF_TOKEN", "")
|
| 72 |
+
storage_client = HFStorageClient(dataset_repo=HF_DATASET_REPO, hf_token=HF_TOKEN)
|
| 73 |
+
|
| 74 |
+
# Auth storage and email service
|
| 75 |
+
auth_storage = AuthStorageClient(dataset_repo=HF_DATASET_REPO, hf_token=HF_TOKEN)
|
| 76 |
+
email_service = EmailService()
|
| 77 |
+
MASTER_PASSWORD = os.getenv("MASTER_PASSWORD", "")
|
| 78 |
+
|
| 79 |
+
# Setup CORS
|
| 80 |
+
allowed_origins = os.getenv("ALLOWED_ORIGINS", "http://localhost:4200").split(",")
|
| 81 |
+
allowed_origins_clean = [origin.strip() for origin in allowed_origins if origin.strip()]
|
| 82 |
+
logger.info(f"CORS allowed origins: {allowed_origins_clean}")
|
| 83 |
+
app.add_middleware(
|
| 84 |
+
CORSMiddleware,
|
| 85 |
+
allow_origins=allowed_origins_clean,
|
| 86 |
+
allow_credentials=True,
|
| 87 |
+
allow_methods=["*"],
|
| 88 |
+
allow_headers=["*"],
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
"""
|
| 92 |
+
With Hugging Face storage, images are served via absolute URLs from the
|
| 93 |
+
dataset CDN, so we no longer mount local static directories.
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
# Keep track of running jobs using asyncio.Queue for thread-safe updates
|
| 97 |
+
job_queues = {}
|
| 98 |
+
active_jobs = {}
|
| 99 |
+
|
| 100 |
+
class StyleTransferProgress(BaseModel):
|
| 101 |
+
job_id: str
|
| 102 |
+
status: str
|
| 103 |
+
progress: Optional[int] = 0
|
| 104 |
+
style_loss: Optional[float] = None
|
| 105 |
+
content_loss: Optional[float] = None
|
| 106 |
+
result_url: Optional[str] = None
|
| 107 |
+
error: Optional[str] = None
|
| 108 |
+
|
| 109 |
+
# Auth models
|
| 110 |
+
class LoginRequest(BaseModel):
|
| 111 |
+
email: Optional[str] = None
|
| 112 |
+
password: Optional[str] = None
|
| 113 |
+
master_password: Optional[str] = None
|
| 114 |
+
|
| 115 |
+
class PermissionRequest(BaseModel):
|
| 116 |
+
name: str
|
| 117 |
+
email: str
|
| 118 |
+
reason: str
|
| 119 |
+
|
| 120 |
+
class CreateUserRequest(BaseModel):
|
| 121 |
+
email: str
|
| 122 |
+
password: str
|
| 123 |
+
|
| 124 |
+
class ApproveRequestData(BaseModel):
|
| 125 |
+
password: Optional[str] = None # Optional custom password, otherwise auto-generated
|
| 126 |
+
|
| 127 |
+
class RejectRequestData(BaseModel):
|
| 128 |
+
reason: Optional[str] = None
|
| 129 |
+
|
| 130 |
+
# Helper function to generate unique file paths
|
| 131 |
+
def get_unique_filename(directory, extension=".jpg"):
|
| 132 |
+
return os.path.join(directory, f"{uuid.uuid4()}{extension}")
|
| 133 |
+
|
| 134 |
+
def load_gallery():
|
| 135 |
+
try:
|
| 136 |
+
return storage_client.load_gallery()
|
| 137 |
+
except Exception as e:
|
| 138 |
+
logger.error(f"Error loading gallery data from HF: {str(e)}")
|
| 139 |
+
return []
|
| 140 |
+
|
| 141 |
+
def save_gallery(gallery_data):
|
| 142 |
+
try:
|
| 143 |
+
storage_client.save_gallery(gallery_data)
|
| 144 |
+
except Exception as e:
|
| 145 |
+
logger.error(f"Error saving gallery data to HF: {str(e)}")
|
| 146 |
+
|
| 147 |
+
# Background task for style transfer
|
| 148 |
+
async def run_style_transfer_task(
|
| 149 |
+
job_id: str,
|
| 150 |
+
content_local_path: str,
|
| 151 |
+
style_local_path: str,
|
| 152 |
+
output_local_path: str,
|
| 153 |
+
style_weight: float,
|
| 154 |
+
content_weight: float,
|
| 155 |
+
num_steps: int,
|
| 156 |
+
layer_weights: Dict[str, float]
|
| 157 |
+
):
|
| 158 |
+
try:
|
| 159 |
+
# Create a queue for this job if it doesn't exist
|
| 160 |
+
if job_id not in job_queues:
|
| 161 |
+
job_queues[job_id] = asyncio.Queue()
|
| 162 |
+
|
| 163 |
+
queue = job_queues[job_id]
|
| 164 |
+
start_time = time.time()
|
| 165 |
+
best_loss = float('inf')
|
| 166 |
+
style_loss = 0
|
| 167 |
+
content_loss = 0
|
| 168 |
+
|
| 169 |
+
# Update job status
|
| 170 |
+
await queue.put({
|
| 171 |
+
"status": "processing",
|
| 172 |
+
"progress": 0,
|
| 173 |
+
"style_loss": None,
|
| 174 |
+
"content_loss": None
|
| 175 |
+
})
|
| 176 |
+
|
| 177 |
+
# Define a callback that will update the job status
|
| 178 |
+
def on_progress(progress):
|
| 179 |
+
nonlocal style_loss, content_loss, best_loss
|
| 180 |
+
# Calculate total loss as the sum of style and content loss
|
| 181 |
+
total_loss = progress["style_loss"] + progress["content_loss"]
|
| 182 |
+
|
| 183 |
+
# Update the best loss if this one is better
|
| 184 |
+
if total_loss < best_loss:
|
| 185 |
+
best_loss = total_loss
|
| 186 |
+
|
| 187 |
+
progress_data = {
|
| 188 |
+
"status": "processing",
|
| 189 |
+
"progress": progress["iteration"] / num_steps * 100,
|
| 190 |
+
"style_loss": progress["style_loss"],
|
| 191 |
+
"content_loss": progress["content_loss"]
|
| 192 |
+
}
|
| 193 |
+
style_loss = progress["style_loss"]
|
| 194 |
+
content_loss = progress["content_loss"]
|
| 195 |
+
|
| 196 |
+
# Use asyncio.run_coroutine_threadsafe to safely put data in the queue from a different thread
|
| 197 |
+
loop = asyncio.get_event_loop()
|
| 198 |
+
asyncio.run_coroutine_threadsafe(queue.put(progress_data), loop)
|
| 199 |
+
|
| 200 |
+
# Run the style transfer
|
| 201 |
+
result_path, model_best_loss = transfer_style(
|
| 202 |
+
content_path=content_local_path,
|
| 203 |
+
style_path=style_local_path,
|
| 204 |
+
output_path=output_local_path,
|
| 205 |
+
style_weight=style_weight,
|
| 206 |
+
content_weight=content_weight,
|
| 207 |
+
num_steps=num_steps,
|
| 208 |
+
layer_weights=layer_weights,
|
| 209 |
+
progress_callback=on_progress
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
processing_time = time.time() - start_time
|
| 213 |
+
|
| 214 |
+
# If best_loss is still infinity, use the model's best loss or the sum of final losses
|
| 215 |
+
if math.isinf(best_loss):
|
| 216 |
+
best_loss = model_best_loss if not math.isinf(model_best_loss) else style_loss + content_loss
|
| 217 |
+
|
| 218 |
+
# Upload artifacts to Hugging Face dataset and build absolute URLs
|
| 219 |
+
date_prefix = datetime.utcnow().strftime("%Y/%m/%d")
|
| 220 |
+
base_prefix = f"runs/{date_prefix}/{job_id}"
|
| 221 |
+
|
| 222 |
+
content_ds_path = storage_client.upload_file(
|
| 223 |
+
local_path=content_local_path,
|
| 224 |
+
dst_path=f"{base_prefix}/content.jpg"
|
| 225 |
+
)
|
| 226 |
+
style_ds_path = storage_client.upload_file(
|
| 227 |
+
local_path=style_local_path,
|
| 228 |
+
dst_path=f"{base_prefix}/style.jpg"
|
| 229 |
+
)
|
| 230 |
+
result_ds_path = storage_client.upload_file(
|
| 231 |
+
local_path=output_local_path,
|
| 232 |
+
dst_path=f"{base_prefix}/result.jpg"
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
content_url = build_dataset_resolve_url(storage_client.dataset_repo, content_ds_path)
|
| 236 |
+
style_url = build_dataset_resolve_url(storage_client.dataset_repo, style_ds_path)
|
| 237 |
+
result_url = build_dataset_resolve_url(storage_client.dataset_repo, result_ds_path)
|
| 238 |
+
|
| 239 |
+
# Save to gallery
|
| 240 |
+
gallery_item = {
|
| 241 |
+
"id": job_id,
|
| 242 |
+
"timestamp": datetime.utcnow().isoformat(),
|
| 243 |
+
"contentImageUrl": content_url,
|
| 244 |
+
"styleImageUrl": style_url,
|
| 245 |
+
"resultImageUrl": result_url,
|
| 246 |
+
"bestLoss": style_loss + content_loss,
|
| 247 |
+
"styleLoss": style_loss,
|
| 248 |
+
"contentLoss": content_loss,
|
| 249 |
+
"processingTime": processing_time,
|
| 250 |
+
"parameters": {
|
| 251 |
+
"styleWeight": style_weight,
|
| 252 |
+
"contentWeight": content_weight,
|
| 253 |
+
"numSteps": num_steps,
|
| 254 |
+
"layerWeights": layer_weights
|
| 255 |
+
}
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
gallery = load_gallery()
|
| 259 |
+
gallery.append(gallery_item)
|
| 260 |
+
save_gallery(gallery)
|
| 261 |
+
|
| 262 |
+
# Update job status with result
|
| 263 |
+
await queue.put({
|
| 264 |
+
"status": "completed",
|
| 265 |
+
"progress": 100,
|
| 266 |
+
"style_loss": style_loss,
|
| 267 |
+
"content_loss": content_loss,
|
| 268 |
+
"result_url": result_url
|
| 269 |
+
})
|
| 270 |
+
|
| 271 |
+
except Exception as e:
|
| 272 |
+
logger.error(f"Error in style transfer: {str(e)}")
|
| 273 |
+
await queue.put({
|
| 274 |
+
"status": "failed",
|
| 275 |
+
"error": str(e)
|
| 276 |
+
})
|
| 277 |
+
finally:
|
| 278 |
+
# Keep the last status update in active_jobs
|
| 279 |
+
try:
|
| 280 |
+
last_status = queue.get_nowait()
|
| 281 |
+
active_jobs[job_id] = last_status
|
| 282 |
+
except asyncio.QueueEmpty:
|
| 283 |
+
pass
|
| 284 |
+
|
| 285 |
+
@app.post("/api/transfer")
|
| 286 |
+
async def create_style_transfer(
|
| 287 |
+
background_tasks: BackgroundTasks,
|
| 288 |
+
content_image: UploadFile = File(...),
|
| 289 |
+
style_image: UploadFile = File(...),
|
| 290 |
+
style_weight: float = Form(1000000.0),
|
| 291 |
+
content_weight: float = Form(1.0),
|
| 292 |
+
num_steps: int = Form(300),
|
| 293 |
+
layer_weights: str = Form("{}"),
|
| 294 |
+
current_user: Dict[str, Any] = Depends(get_current_user),
|
| 295 |
+
):
|
| 296 |
+
try:
|
| 297 |
+
# Parse layer weights from JSON string
|
| 298 |
+
layer_weights_dict = json.loads(layer_weights)
|
| 299 |
+
|
| 300 |
+
# Save uploaded files temporarily to local disk for processing
|
| 301 |
+
temp_dir = tempfile.gettempdir()
|
| 302 |
+
content_path = get_unique_filename(temp_dir)
|
| 303 |
+
style_path = get_unique_filename(temp_dir)
|
| 304 |
+
output_path = get_unique_filename(temp_dir)
|
| 305 |
+
|
| 306 |
+
with open(content_path, "wb") as content_file:
|
| 307 |
+
content_file.write(await content_image.read())
|
| 308 |
+
|
| 309 |
+
with open(style_path, "wb") as style_file:
|
| 310 |
+
style_file.write(await style_image.read())
|
| 311 |
+
|
| 312 |
+
# Create a unique job ID
|
| 313 |
+
job_id = str(uuid.uuid4())
|
| 314 |
+
|
| 315 |
+
# Initialize job status
|
| 316 |
+
active_jobs[job_id] = {
|
| 317 |
+
"status": "pending",
|
| 318 |
+
"progress": 0,
|
| 319 |
+
"style_loss": None,
|
| 320 |
+
"content_loss": None,
|
| 321 |
+
"result_url": None,
|
| 322 |
+
"error": None
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
# Start style transfer in the background
|
| 326 |
+
background_tasks.add_task(
|
| 327 |
+
run_style_transfer_task,
|
| 328 |
+
job_id,
|
| 329 |
+
content_path,
|
| 330 |
+
style_path,
|
| 331 |
+
output_path,
|
| 332 |
+
style_weight,
|
| 333 |
+
content_weight,
|
| 334 |
+
num_steps,
|
| 335 |
+
layer_weights_dict
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
return {
|
| 339 |
+
"job_id": job_id,
|
| 340 |
+
"status": "pending"
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
except Exception as e:
|
| 344 |
+
logger.error(f"Error creating style transfer: {str(e)}")
|
| 345 |
+
return CustomJSONResponse(
|
| 346 |
+
status_code=500,
|
| 347 |
+
content={"error": str(e)}
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
@app.get("/api/transfer/{job_id}")
|
| 351 |
+
async def get_transfer_status(job_id: str):
|
| 352 |
+
if job_id not in active_jobs and job_id not in job_queues:
|
| 353 |
+
return CustomJSONResponse(
|
| 354 |
+
status_code=404,
|
| 355 |
+
content={"error": "Job not found"}
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
# Try to get the latest status from the queue
|
| 359 |
+
if job_id in job_queues:
|
| 360 |
+
try:
|
| 361 |
+
# Get the latest status without removing it from the queue
|
| 362 |
+
status = job_queues[job_id].get_nowait()
|
| 363 |
+
job_queues[job_id].put_nowait(status) # Put it back
|
| 364 |
+
active_jobs[job_id] = status # Update active_jobs with latest status
|
| 365 |
+
except asyncio.QueueEmpty:
|
| 366 |
+
# If queue is empty, use the last known status from active_jobs
|
| 367 |
+
pass
|
| 368 |
+
|
| 369 |
+
job_status = active_jobs[job_id]
|
| 370 |
+
|
| 371 |
+
# Return appropriate response based on job status
|
| 372 |
+
if job_status["status"] == "completed" and job_status.get("result_url"):
|
| 373 |
+
# Clean up the queue for completed jobs
|
| 374 |
+
if job_id in job_queues:
|
| 375 |
+
del job_queues[job_id]
|
| 376 |
+
return {
|
| 377 |
+
"job_id": job_id,
|
| 378 |
+
"status": "completed",
|
| 379 |
+
"progress": 100,
|
| 380 |
+
"style_loss": job_status.get("style_loss"),
|
| 381 |
+
"content_loss": job_status.get("content_loss"),
|
| 382 |
+
"result_url": job_status["result_url"]
|
| 383 |
+
}
|
| 384 |
+
elif job_status["status"] == "failed":
|
| 385 |
+
# Clean up the queue for failed jobs
|
| 386 |
+
if job_id in job_queues:
|
| 387 |
+
del job_queues[job_id]
|
| 388 |
+
return {
|
| 389 |
+
"job_id": job_id,
|
| 390 |
+
"status": "failed",
|
| 391 |
+
"error": job_status.get("error", "Unknown error")
|
| 392 |
+
}
|
| 393 |
+
else:
|
| 394 |
+
return {
|
| 395 |
+
"job_id": job_id,
|
| 396 |
+
"status": job_status["status"],
|
| 397 |
+
"progress": job_status.get("progress", 0),
|
| 398 |
+
"style_loss": job_status.get("style_loss"),
|
| 399 |
+
"content_loss": job_status.get("content_loss")
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
@app.get("/api/health")
|
| 403 |
+
async def health_check():
|
| 404 |
+
return {"status": "ok"}
|
| 405 |
+
|
| 406 |
+
@app.get("/api/gallery")
|
| 407 |
+
async def get_gallery_items():
|
| 408 |
+
try:
|
| 409 |
+
gallery = load_gallery()
|
| 410 |
+
# Replace any infinity values before sending response
|
| 411 |
+
for item in gallery:
|
| 412 |
+
if 'bestLoss' in item and isinstance(item['bestLoss'], float) and math.isinf(item['bestLoss']):
|
| 413 |
+
item['bestLoss'] = 999999999 if item['bestLoss'] > 0 else -999999999
|
| 414 |
+
if 'styleLoss' in item and isinstance(item['styleLoss'], float) and math.isinf(item['styleLoss']):
|
| 415 |
+
item['styleLoss'] = 999999999 if item['styleLoss'] > 0 else -999999999
|
| 416 |
+
if 'contentLoss' in item and isinstance(item['contentLoss'], float) and math.isinf(item['contentLoss']):
|
| 417 |
+
item['contentLoss'] = 999999999 if item['contentLoss'] > 0 else -999999999
|
| 418 |
+
return gallery
|
| 419 |
+
except Exception as e:
|
| 420 |
+
logger.error(f"Error getting gallery items: {str(e)}")
|
| 421 |
+
return CustomJSONResponse(
|
| 422 |
+
status_code=500,
|
| 423 |
+
content={"error": str(e)}
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
@app.get("/api/gallery/{item_id}")
|
| 427 |
+
async def get_gallery_item(item_id: str):
|
| 428 |
+
try:
|
| 429 |
+
gallery = load_gallery()
|
| 430 |
+
item = next((item for item in gallery if item["id"] == item_id), None)
|
| 431 |
+
if item is None:
|
| 432 |
+
return CustomJSONResponse(status_code=404, content={"error": "Item not found"})
|
| 433 |
+
|
| 434 |
+
# Replace any infinity values before sending response
|
| 435 |
+
if 'bestLoss' in item and isinstance(item['bestLoss'], float) and math.isinf(item['bestLoss']):
|
| 436 |
+
item['bestLoss'] = 999999999 if item['bestLoss'] > 0 else -999999999
|
| 437 |
+
if 'styleLoss' in item and isinstance(item['styleLoss'], float) and math.isinf(item['styleLoss']):
|
| 438 |
+
item['styleLoss'] = 999999999 if item['styleLoss'] > 0 else -999999999
|
| 439 |
+
if 'contentLoss' in item and isinstance(item['contentLoss'], float) and math.isinf(item['contentLoss']):
|
| 440 |
+
item['contentLoss'] = 999999999 if item['contentLoss'] > 0 else -999999999
|
| 441 |
+
|
| 442 |
+
return item
|
| 443 |
+
except Exception as e:
|
| 444 |
+
logger.error(f"Error getting gallery item: {str(e)}")
|
| 445 |
+
return CustomJSONResponse(
|
| 446 |
+
status_code=500,
|
| 447 |
+
content={"error": str(e)}
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
@app.delete("/api/gallery/{item_id}")
|
| 451 |
+
async def delete_gallery_item(
|
| 452 |
+
item_id: str,
|
| 453 |
+
current_user: Dict[str, Any] = Depends(get_current_user),
|
| 454 |
+
):
|
| 455 |
+
try:
|
| 456 |
+
gallery = load_gallery()
|
| 457 |
+
item_to_delete = next((item for item in gallery if item["id"] == item_id), None)
|
| 458 |
+
|
| 459 |
+
if not item_to_delete:
|
| 460 |
+
return CustomJSONResponse(
|
| 461 |
+
status_code=404,
|
| 462 |
+
content={"error": "Item not found"}
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
# Remove from gallery first
|
| 466 |
+
gallery = [item for item in gallery if item["id"] != item_id]
|
| 467 |
+
save_gallery(gallery)
|
| 468 |
+
|
| 469 |
+
# Attempt to delete artifacts from dataset
|
| 470 |
+
try:
|
| 471 |
+
storage_client.delete_run_artifacts(item_to_delete)
|
| 472 |
+
except Exception as e:
|
| 473 |
+
logger.error(f"Error deleting dataset artifacts for {item_id}: {str(e)}")
|
| 474 |
+
|
| 475 |
+
return {"status": "success"}
|
| 476 |
+
except Exception as e:
|
| 477 |
+
logger.error(f"Error deleting gallery item: {str(e)}")
|
| 478 |
+
return CustomJSONResponse(
|
| 479 |
+
status_code=500,
|
| 480 |
+
content={"error": str(e)}
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
# Authentication endpoints
|
| 484 |
+
@app.post("/api/auth/login")
|
| 485 |
+
async def login(login_request: LoginRequest):
|
| 486 |
+
"""
|
| 487 |
+
Login with email/password or master password.
|
| 488 |
+
"""
|
| 489 |
+
logger.info(f"Login attempt - email: {login_request.email}, has_master_password: {bool(login_request.master_password)}")
|
| 490 |
+
try:
|
| 491 |
+
# Check master password first (doesn't require email or password)
|
| 492 |
+
if login_request.master_password and MASTER_PASSWORD and login_request.master_password == MASTER_PASSWORD:
|
| 493 |
+
access_token = create_access_token(data={"email": None, "is_master": True})
|
| 494 |
+
return {"access_token": access_token, "token_type": "bearer", "is_master": True}
|
| 495 |
+
|
| 496 |
+
# Check user email/password (requires both email and password)
|
| 497 |
+
if login_request.email and login_request.password:
|
| 498 |
+
user = auth_storage.get_user(login_request.email)
|
| 499 |
+
if user and verify_password(login_request.password, user["password_hash"]):
|
| 500 |
+
access_token = create_access_token(data={"email": login_request.email, "is_master": False})
|
| 501 |
+
return {"access_token": access_token, "token_type": "bearer", "is_master": False, "email": login_request.email}
|
| 502 |
+
|
| 503 |
+
raise HTTPException(
|
| 504 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 505 |
+
detail="Incorrect email/password or master password"
|
| 506 |
+
)
|
| 507 |
+
except HTTPException:
|
| 508 |
+
raise
|
| 509 |
+
except Exception as e:
|
| 510 |
+
logger.error(f"Error in login: {str(e)}")
|
| 511 |
+
raise HTTPException(
|
| 512 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 513 |
+
detail="Login failed"
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
@app.post("/api/auth/requests")
|
| 517 |
+
async def submit_permission_request(request: PermissionRequest):
|
| 518 |
+
"""
|
| 519 |
+
Submit a permission request. Public endpoint.
|
| 520 |
+
"""
|
| 521 |
+
try:
|
| 522 |
+
request_id = auth_storage.add_request(
|
| 523 |
+
name=request.name,
|
| 524 |
+
email=request.email,
|
| 525 |
+
reason=request.reason
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
# Get the request to get timestamp
|
| 529 |
+
req = auth_storage.get_request(request_id)
|
| 530 |
+
if req:
|
| 531 |
+
# Send email notification to admin
|
| 532 |
+
email_service.send_permission_request_notification(
|
| 533 |
+
name=request.name,
|
| 534 |
+
email=request.email,
|
| 535 |
+
reason=request.reason,
|
| 536 |
+
timestamp=req.get("timestamp", "")
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
return {"request_id": request_id, "status": "submitted"}
|
| 540 |
+
except Exception as e:
|
| 541 |
+
logger.error(f"Error submitting permission request: {str(e)}")
|
| 542 |
+
raise HTTPException(
|
| 543 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 544 |
+
detail="Failed to submit request"
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
def verify_master_password(user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
|
| 548 |
+
"""Verify that the current user is using master password."""
|
| 549 |
+
if not user.get("is_master"):
|
| 550 |
+
raise HTTPException(
|
| 551 |
+
status_code=status.HTTP_403_FORBIDDEN,
|
| 552 |
+
detail="Master password required"
|
| 553 |
+
)
|
| 554 |
+
return user
|
| 555 |
+
|
| 556 |
+
@app.get("/api/auth/requests")
|
| 557 |
+
async def list_permission_requests(
|
| 558 |
+
admin_user: Dict[str, Any] = Depends(verify_master_password)
|
| 559 |
+
):
|
| 560 |
+
"""List all permission requests. Admin only."""
|
| 561 |
+
try:
|
| 562 |
+
requests = auth_storage.load_requests()
|
| 563 |
+
return {"requests": requests}
|
| 564 |
+
except Exception as e:
|
| 565 |
+
logger.error(f"Error listing requests: {str(e)}")
|
| 566 |
+
raise HTTPException(
|
| 567 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 568 |
+
detail="Failed to list requests"
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
@app.post("/api/auth/requests/{request_id}/approve")
|
| 572 |
+
async def approve_request(
|
| 573 |
+
request_id: str,
|
| 574 |
+
approve_data: ApproveRequestData,
|
| 575 |
+
admin_user: Dict[str, Any] = Depends(verify_master_password)
|
| 576 |
+
):
|
| 577 |
+
"""Approve a permission request and create user account. Admin only."""
|
| 578 |
+
try:
|
| 579 |
+
req = auth_storage.get_request(request_id)
|
| 580 |
+
if not req:
|
| 581 |
+
raise HTTPException(status_code=404, detail="Request not found")
|
| 582 |
+
|
| 583 |
+
if req.get("status") != "pending":
|
| 584 |
+
raise HTTPException(status_code=400, detail="Request already processed")
|
| 585 |
+
|
| 586 |
+
# Generate password if not provided
|
| 587 |
+
password = approve_data.password or secrets.token_urlsafe(12)
|
| 588 |
+
password_hash = get_password_hash(password)
|
| 589 |
+
|
| 590 |
+
# Create user account
|
| 591 |
+
try:
|
| 592 |
+
auth_storage.add_user(email=req["email"], password_hash=password_hash)
|
| 593 |
+
except ValueError as e:
|
| 594 |
+
# User might already exist
|
| 595 |
+
logger.warning(f"User creation warning: {str(e)}")
|
| 596 |
+
|
| 597 |
+
# Update request status
|
| 598 |
+
auth_storage.update_request_status(request_id, "approved")
|
| 599 |
+
|
| 600 |
+
# Send approval email
|
| 601 |
+
email_service.send_approval_email(
|
| 602 |
+
user_email=req["email"],
|
| 603 |
+
user_name=req["name"],
|
| 604 |
+
password=password
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
return {"status": "approved", "email": req["email"]}
|
| 608 |
+
except HTTPException:
|
| 609 |
+
raise
|
| 610 |
+
except Exception as e:
|
| 611 |
+
logger.error(f"Error approving request: {str(e)}")
|
| 612 |
+
raise HTTPException(
|
| 613 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 614 |
+
detail="Failed to approve request"
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
@app.post("/api/auth/requests/{request_id}/reject")
|
| 618 |
+
async def reject_request(
|
| 619 |
+
request_id: str,
|
| 620 |
+
reject_data: RejectRequestData,
|
| 621 |
+
admin_user: Dict[str, Any] = Depends(verify_master_password)
|
| 622 |
+
):
|
| 623 |
+
"""Reject a permission request. Admin only."""
|
| 624 |
+
try:
|
| 625 |
+
req = auth_storage.get_request(request_id)
|
| 626 |
+
if not req:
|
| 627 |
+
raise HTTPException(status_code=404, detail="Request not found")
|
| 628 |
+
|
| 629 |
+
if req.get("status") != "pending":
|
| 630 |
+
raise HTTPException(status_code=400, detail="Request already processed")
|
| 631 |
+
|
| 632 |
+
# Update request status
|
| 633 |
+
auth_storage.update_request_status(request_id, "rejected", reject_data.reason)
|
| 634 |
+
|
| 635 |
+
# Send rejection email
|
| 636 |
+
email_service.send_rejection_email(
|
| 637 |
+
user_email=req["email"],
|
| 638 |
+
user_name=req["name"],
|
| 639 |
+
reason=reject_data.reason
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
return {"status": "rejected"}
|
| 643 |
+
except HTTPException:
|
| 644 |
+
raise
|
| 645 |
+
except Exception as e:
|
| 646 |
+
logger.error(f"Error rejecting request: {str(e)}")
|
| 647 |
+
raise HTTPException(
|
| 648 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 649 |
+
detail="Failed to reject request"
|
| 650 |
+
)
|
| 651 |
+
|
| 652 |
+
@app.delete("/api/auth/requests/{request_id}")
|
| 653 |
+
async def delete_request(
|
| 654 |
+
request_id: str,
|
| 655 |
+
admin_user: Dict[str, Any] = Depends(verify_master_password)
|
| 656 |
+
):
|
| 657 |
+
"""Delete a permission request without sending email. Admin only."""
|
| 658 |
+
try:
|
| 659 |
+
auth_storage.delete_request(request_id)
|
| 660 |
+
return {"status": "deleted"}
|
| 661 |
+
except Exception as e:
|
| 662 |
+
logger.error(f"Error deleting request: {str(e)}")
|
| 663 |
+
raise HTTPException(
|
| 664 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 665 |
+
detail="Failed to delete request"
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
@app.post("/api/auth/users")
|
| 669 |
+
async def create_user(
|
| 670 |
+
user_request: CreateUserRequest,
|
| 671 |
+
admin_user: Dict[str, Any] = Depends(verify_master_password)
|
| 672 |
+
):
|
| 673 |
+
"""Create a new user. Admin only."""
|
| 674 |
+
try:
|
| 675 |
+
password_hash = get_password_hash(user_request.password)
|
| 676 |
+
auth_storage.add_user(email=user_request.email, password_hash=password_hash)
|
| 677 |
+
return {"status": "created", "email": user_request.email}
|
| 678 |
+
except ValueError as e:
|
| 679 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 680 |
+
except Exception as e:
|
| 681 |
+
logger.error(f"Error creating user: {str(e)}")
|
| 682 |
+
raise HTTPException(
|
| 683 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 684 |
+
detail="Failed to create user"
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
@app.get("/api/auth/users")
|
| 688 |
+
async def list_users(
|
| 689 |
+
admin_user: Dict[str, Any] = Depends(verify_master_password)
|
| 690 |
+
):
|
| 691 |
+
"""List all users. Admin only."""
|
| 692 |
+
try:
|
| 693 |
+
users = auth_storage.load_users()
|
| 694 |
+
# Don't return password hashes
|
| 695 |
+
user_list = [{"email": user["email"]} for user in users]
|
| 696 |
+
return {"users": user_list}
|
| 697 |
+
except Exception as e:
|
| 698 |
+
logger.error(f"Error listing users: {str(e)}")
|
| 699 |
+
raise HTTPException(
|
| 700 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 701 |
+
detail="Failed to list users"
|
| 702 |
+
)
|
| 703 |
+
|
| 704 |
+
@app.delete("/api/auth/users/{email}")
|
| 705 |
+
async def delete_user(
|
| 706 |
+
email: str,
|
| 707 |
+
admin_user: Dict[str, Any] = Depends(verify_master_password)
|
| 708 |
+
):
|
| 709 |
+
"""Delete a user. Admin only."""
|
| 710 |
+
try:
|
| 711 |
+
auth_storage.delete_user(email)
|
| 712 |
+
return {"status": "deleted", "email": email}
|
| 713 |
+
except Exception as e:
|
| 714 |
+
logger.error(f"Error deleting user: {str(e)}")
|
| 715 |
+
raise HTTPException(
|
| 716 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 717 |
+
detail="Failed to delete user"
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
if __name__ == "__main__":
|
| 721 |
+
import uvicorn
|
| 722 |
+
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
torchvision>=0.15.0
|
| 3 |
+
pillow>=10.0.0
|
| 4 |
+
numpy>=1.25.2
|
| 5 |
+
fastapi==0.103.1
|
| 6 |
+
uvicorn==0.23.2
|
| 7 |
+
python-multipart==0.0.6
|
| 8 |
+
aiofiles==23.2.1
|
| 9 |
+
matplotlib>=3.7.2
|
| 10 |
+
requests==2.31.0
|
| 11 |
+
python-dotenv==1.0.0
|
| 12 |
+
huggingface_hub>=0.26.0
|
| 13 |
+
python-jose[cryptography]>=3.3.0
|
| 14 |
+
passlib[bcrypt]>=1.7.4
|
run.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uvicorn
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# Create necessary directories
|
| 5 |
+
os.makedirs("uploads", exist_ok=True)
|
| 6 |
+
os.makedirs("results", exist_ok=True)
|
| 7 |
+
|
| 8 |
+
if __name__ == "__main__":
|
| 9 |
+
print("Starting Neural Style Transfer API")
|
| 10 |
+
uvicorn.run("main:app", host="0.0.0.0", port=7860)
|
storage_hf.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Any, Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
from huggingface_hub import HfApi, CommitOperationAdd, CommitOperationDelete, create_commit, hf_hub_url
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
GALLERY_FILE_PATH = "gallery/gallery.json"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def build_dataset_resolve_url(repo_id: str, path_in_repo: str, revision: str = "main") -> str:
|
| 15 |
+
"""
|
| 16 |
+
Build a CDN-resolved URL for a file stored in a Hugging Face dataset repo.
|
| 17 |
+
"""
|
| 18 |
+
return hf_hub_url(repo_id=repo_id, filename=path_in_repo, repo_type="dataset", revision=revision)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class HFStorageClient:
|
| 22 |
+
"""
|
| 23 |
+
Simple helper around huggingface_hub for storing run artifacts and gallery metadata
|
| 24 |
+
in a Dataset repository.
|
| 25 |
+
|
| 26 |
+
Repo format:
|
| 27 |
+
- runs/YYYY/MM/DD/<job_id>/content.jpg
|
| 28 |
+
- runs/YYYY/MM/DD/<job_id>/style.jpg
|
| 29 |
+
- runs/YYYY/MM/DD/<job_id>/result.jpg
|
| 30 |
+
- gallery/gallery.json
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, dataset_repo: str, hf_token: Optional[str] = None, revision: str = "main"):
|
| 34 |
+
if not dataset_repo:
|
| 35 |
+
raise ValueError("HF_DATASET_REPO is not set. Please configure the dataset repository id.")
|
| 36 |
+
self.dataset_repo = dataset_repo
|
| 37 |
+
self.revision = revision
|
| 38 |
+
self.api = HfApi(token=hf_token) if hf_token else HfApi()
|
| 39 |
+
|
| 40 |
+
def load_gallery(self) -> List[Dict[str, Any]]:
|
| 41 |
+
"""
|
| 42 |
+
Download and parse gallery.json from the dataset. If missing, return [].
|
| 43 |
+
"""
|
| 44 |
+
try:
|
| 45 |
+
# Try to get the raw file content via the hub URL
|
| 46 |
+
url = build_dataset_resolve_url(self.dataset_repo, GALLERY_FILE_PATH, self.revision)
|
| 47 |
+
import requests # local import to avoid hard dependency elsewhere
|
| 48 |
+
headers = {}
|
| 49 |
+
if self.api.token:
|
| 50 |
+
headers["Authorization"] = f"Bearer {self.api.token}"
|
| 51 |
+
resp = requests.get(url, timeout=10, headers=headers)
|
| 52 |
+
if resp.status_code == 200:
|
| 53 |
+
return resp.json()
|
| 54 |
+
logger.info("Gallery not found at %s (status %s). Initializing empty gallery.", url, resp.status_code)
|
| 55 |
+
return []
|
| 56 |
+
except Exception as e:
|
| 57 |
+
logger.error("Failed to load gallery from HF: %s", str(e))
|
| 58 |
+
return []
|
| 59 |
+
|
| 60 |
+
def save_gallery(self, gallery: List[Dict[str, Any]]) -> None:
|
| 61 |
+
"""
|
| 62 |
+
Commit a new version of gallery.json to the dataset repo.
|
| 63 |
+
"""
|
| 64 |
+
try:
|
| 65 |
+
payload = json.dumps(gallery, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
|
| 66 |
+
operations = [
|
| 67 |
+
CommitOperationAdd(path_in_repo=GALLERY_FILE_PATH, path_or_fileobj=payload)
|
| 68 |
+
]
|
| 69 |
+
create_commit(
|
| 70 |
+
repo_id=self.dataset_repo,
|
| 71 |
+
repo_type="dataset",
|
| 72 |
+
operations=operations,
|
| 73 |
+
commit_message="Update gallery.json",
|
| 74 |
+
revision=self.revision,
|
| 75 |
+
token=self.api.token,
|
| 76 |
+
)
|
| 77 |
+
except Exception as e:
|
| 78 |
+
logger.error("Failed to save gallery to HF: %s", str(e))
|
| 79 |
+
raise
|
| 80 |
+
|
| 81 |
+
def upload_file(self, local_path: str, dst_path: str) -> str:
|
| 82 |
+
"""
|
| 83 |
+
Upload a local file to the dataset repo at dst_path. Returns the path_in_repo.
|
| 84 |
+
"""
|
| 85 |
+
if not os.path.exists(local_path):
|
| 86 |
+
raise FileNotFoundError(local_path)
|
| 87 |
+
|
| 88 |
+
try:
|
| 89 |
+
with open(local_path, "rb") as f:
|
| 90 |
+
operations = [
|
| 91 |
+
CommitOperationAdd(path_in_repo=dst_path, path_or_fileobj=f)
|
| 92 |
+
]
|
| 93 |
+
create_commit(
|
| 94 |
+
repo_id=self.dataset_repo,
|
| 95 |
+
repo_type="dataset",
|
| 96 |
+
operations=operations,
|
| 97 |
+
commit_message=f"Upload {dst_path}",
|
| 98 |
+
revision=self.revision,
|
| 99 |
+
token=self.api.token,
|
| 100 |
+
)
|
| 101 |
+
return dst_path
|
| 102 |
+
except Exception as e:
|
| 103 |
+
logger.error("Failed to upload %s to HF at %s: %s", local_path, dst_path, str(e))
|
| 104 |
+
raise
|
| 105 |
+
|
| 106 |
+
def delete_run_artifacts(self, gallery_item: Dict[str, Any]) -> None:
|
| 107 |
+
"""
|
| 108 |
+
Attempt to delete the three image artifacts associated with a run.
|
| 109 |
+
This parses resolve URLs to determine paths in repo.
|
| 110 |
+
"""
|
| 111 |
+
def extract_path(url: Optional[str]) -> Optional[str]:
|
| 112 |
+
if not url:
|
| 113 |
+
return None
|
| 114 |
+
marker = "/resolve/"
|
| 115 |
+
if marker in url:
|
| 116 |
+
try:
|
| 117 |
+
# url ends with .../resolve/<rev>/<path_in_repo>
|
| 118 |
+
parts = url.split(marker, 1)[1].split("/", 1)
|
| 119 |
+
if len(parts) == 2:
|
| 120 |
+
return parts[1]
|
| 121 |
+
except Exception:
|
| 122 |
+
return None
|
| 123 |
+
return None
|
| 124 |
+
|
| 125 |
+
paths: List[str] = []
|
| 126 |
+
for key in ("contentImageUrl", "styleImageUrl", "resultImageUrl"):
|
| 127 |
+
p = extract_path(gallery_item.get(key))
|
| 128 |
+
if p:
|
| 129 |
+
paths.append(p)
|
| 130 |
+
|
| 131 |
+
if not paths:
|
| 132 |
+
return
|
| 133 |
+
|
| 134 |
+
try:
|
| 135 |
+
operations = [CommitOperationDelete(path) for path in paths]
|
| 136 |
+
create_commit(
|
| 137 |
+
repo_id=self.dataset_repo,
|
| 138 |
+
repo_type="dataset",
|
| 139 |
+
operations=operations,
|
| 140 |
+
commit_message=f"Delete artifacts for run {gallery_item.get('id', '')}",
|
| 141 |
+
revision=self.revision,
|
| 142 |
+
token=self.api.token,
|
| 143 |
+
)
|
| 144 |
+
except Exception as e:
|
| 145 |
+
logger.error("Failed to delete artifacts %s: %s", paths, str(e))
|
| 146 |
+
|
style_transfer.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import torchvision.transforms as transforms
|
| 7 |
+
import torchvision.models as models
|
| 8 |
+
import copy
|
| 9 |
+
import time
|
| 10 |
+
import os
|
| 11 |
+
import io
|
| 12 |
+
|
| 13 |
+
# Check if GPU is available
|
| 14 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 15 |
+
print(f"Using device: {device}")
|
| 16 |
+
|
| 17 |
+
# Image loading and preprocessing
|
| 18 |
+
def image_loader(image_path, imsize=512):
|
| 19 |
+
loader = transforms.Compose([
|
| 20 |
+
transforms.Resize(imsize), # Scale imported image
|
| 21 |
+
transforms.CenterCrop(imsize), # Ensure square size
|
| 22 |
+
transforms.ToTensor(), # Transform into torch tensor
|
| 23 |
+
transforms.Lambda(lambda x: x.repeat(1, 1, 1) if x.size(0) == 1 else x) # Convert grayscale to RGB if needed
|
| 24 |
+
])
|
| 25 |
+
|
| 26 |
+
image = Image.open(image_path).convert('RGB') # Ensure image is RGB
|
| 27 |
+
# Add batch dimension (1, 3, h, w)
|
| 28 |
+
image = loader(image).unsqueeze(0)
|
| 29 |
+
return image.to(device, torch.float)
|
| 30 |
+
|
| 31 |
+
def load_image_from_bytes(image_bytes, imsize=512):
|
| 32 |
+
loader = transforms.Compose([
|
| 33 |
+
transforms.Resize(imsize),
|
| 34 |
+
transforms.CenterCrop(imsize),
|
| 35 |
+
transforms.ToTensor(),
|
| 36 |
+
transforms.Lambda(lambda x: x.repeat(1, 1, 1) if x.size(0) == 1 else x)
|
| 37 |
+
])
|
| 38 |
+
|
| 39 |
+
image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
|
| 40 |
+
image = loader(image).unsqueeze(0)
|
| 41 |
+
return image.to(device, torch.float)
|
| 42 |
+
|
| 43 |
+
# Content Loss: Measures content similarity
|
| 44 |
+
class ContentLoss(nn.Module):
|
| 45 |
+
def __init__(self, target):
|
| 46 |
+
super(ContentLoss, self).__init__()
|
| 47 |
+
# Detach the target content from the tree used to dynamically compute gradients
|
| 48 |
+
self.target = target.detach()
|
| 49 |
+
|
| 50 |
+
def forward(self, input):
|
| 51 |
+
self.loss = F.mse_loss(input, self.target)
|
| 52 |
+
return input
|
| 53 |
+
|
| 54 |
+
# Gram matrix calculation for style representation
|
| 55 |
+
def gram_matrix(input):
|
| 56 |
+
batch_size, n_channels, height, width = input.size()
|
| 57 |
+
features = input.view(batch_size * n_channels, height * width)
|
| 58 |
+
G = torch.mm(features, features.t())
|
| 59 |
+
# Normalize by total number of elements
|
| 60 |
+
return G.div(batch_size * n_channels * height * width)
|
| 61 |
+
|
| 62 |
+
# Style Loss: Measures style similarity using Gram matrices
|
| 63 |
+
class StyleLoss(nn.Module):
|
| 64 |
+
def __init__(self, target_feature):
|
| 65 |
+
super(StyleLoss, self).__init__()
|
| 66 |
+
self.target = gram_matrix(target_feature).detach()
|
| 67 |
+
self.weight = 1.0 # Default weight for this layer
|
| 68 |
+
|
| 69 |
+
def forward(self, input):
|
| 70 |
+
G = gram_matrix(input)
|
| 71 |
+
self.loss = F.mse_loss(G, self.target)
|
| 72 |
+
return input
|
| 73 |
+
|
| 74 |
+
# Normalization layer for VGG compatibility
|
| 75 |
+
class Normalization(nn.Module):
|
| 76 |
+
def __init__(self, mean, std):
|
| 77 |
+
super(Normalization, self).__init__()
|
| 78 |
+
# View the mean and std as 1x3x1x1 tensors
|
| 79 |
+
self.mean = mean.clone().detach().view(-1, 1, 1).to(device)
|
| 80 |
+
self.std = std.clone().detach().view(-1, 1, 1).to(device)
|
| 81 |
+
|
| 82 |
+
def forward(self, img):
|
| 83 |
+
# Normalize img
|
| 84 |
+
return (img - self.mean) / self.std
|
| 85 |
+
|
| 86 |
+
# Build model with content and style losses
|
| 87 |
+
def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
|
| 88 |
+
style_img, content_img,
|
| 89 |
+
content_layers=['conv_4'],
|
| 90 |
+
style_layers=['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'],
|
| 91 |
+
layer_weights=None):
|
| 92 |
+
normalization = Normalization(normalization_mean, normalization_std)
|
| 93 |
+
|
| 94 |
+
# Set default layer weights if not provided
|
| 95 |
+
if layer_weights is None:
|
| 96 |
+
layer_weights = {layer: 1.0 for layer in style_layers}
|
| 97 |
+
|
| 98 |
+
# Lists to keep track of losses
|
| 99 |
+
content_losses = []
|
| 100 |
+
style_losses = []
|
| 101 |
+
|
| 102 |
+
# Create a "sequential" module with added content/style loss layers
|
| 103 |
+
model = nn.Sequential(normalization)
|
| 104 |
+
|
| 105 |
+
i = 0 # Increment for each conv layer
|
| 106 |
+
for layer in cnn.children():
|
| 107 |
+
if isinstance(layer, nn.Conv2d):
|
| 108 |
+
i += 1
|
| 109 |
+
name = f'conv_{i}'
|
| 110 |
+
elif isinstance(layer, nn.ReLU):
|
| 111 |
+
name = f'relu_{i}'
|
| 112 |
+
# Replace in-place version with out-of-place
|
| 113 |
+
layer = nn.ReLU(inplace=False)
|
| 114 |
+
elif isinstance(layer, nn.MaxPool2d):
|
| 115 |
+
name = f'pool_{i}'
|
| 116 |
+
elif isinstance(layer, nn.BatchNorm2d):
|
| 117 |
+
name = f'bn_{i}'
|
| 118 |
+
else:
|
| 119 |
+
raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')
|
| 120 |
+
|
| 121 |
+
model.add_module(name, layer)
|
| 122 |
+
|
| 123 |
+
# Add content loss
|
| 124 |
+
if name in content_layers:
|
| 125 |
+
# Add content loss:
|
| 126 |
+
target = model(content_img).detach()
|
| 127 |
+
content_loss = ContentLoss(target)
|
| 128 |
+
model.add_module(f"content_loss_{i}", content_loss)
|
| 129 |
+
content_losses.append(content_loss)
|
| 130 |
+
|
| 131 |
+
# Add style loss
|
| 132 |
+
if name in style_layers:
|
| 133 |
+
# Add style loss:
|
| 134 |
+
target_feature = model(style_img).detach()
|
| 135 |
+
style_loss = StyleLoss(target_feature)
|
| 136 |
+
|
| 137 |
+
# Apply customized layer weight
|
| 138 |
+
style_loss.weight = layer_weights.get(name, 1.0)
|
| 139 |
+
|
| 140 |
+
model.add_module(f"style_loss_{i}", style_loss)
|
| 141 |
+
style_losses.append(style_loss)
|
| 142 |
+
|
| 143 |
+
# Trim off the layers after the last content and style losses
|
| 144 |
+
for i in range(len(model) - 1, -1, -1):
|
| 145 |
+
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
|
| 146 |
+
break
|
| 147 |
+
|
| 148 |
+
model = model[:(i + 1)]
|
| 149 |
+
|
| 150 |
+
return model, style_losses, content_losses
|
| 151 |
+
|
| 152 |
+
# Optimization loop for style transfer
|
| 153 |
+
def run_style_transfer(cnn, normalization_mean, normalization_std,
|
| 154 |
+
content_img, style_img, input_img, num_steps=300,
|
| 155 |
+
style_weight=1000000, content_weight=1,
|
| 156 |
+
layer_weights=None, progress_callback=None):
|
| 157 |
+
"""Run the style transfer."""
|
| 158 |
+
num_steps = min(num_steps, 400)
|
| 159 |
+
|
| 160 |
+
print('Building the style transfer model...')
|
| 161 |
+
model, style_losses, content_losses = get_style_model_and_losses(
|
| 162 |
+
cnn, normalization_mean, normalization_std,
|
| 163 |
+
style_img, content_img,
|
| 164 |
+
layer_weights=layer_weights
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# We want to optimize the input image only
|
| 168 |
+
input_img.requires_grad_(True)
|
| 169 |
+
model.eval() # We don't need gradients for the model parameters
|
| 170 |
+
model.requires_grad_(False)
|
| 171 |
+
|
| 172 |
+
optimizer = optim.LBFGS([input_img])
|
| 173 |
+
best_img = None
|
| 174 |
+
best_loss = float('inf')
|
| 175 |
+
prev_loss = float('inf')
|
| 176 |
+
current_step = 0
|
| 177 |
+
|
| 178 |
+
start_time = time.time()
|
| 179 |
+
|
| 180 |
+
# Function to be used with optimizer
|
| 181 |
+
def closure():
|
| 182 |
+
nonlocal current_step
|
| 183 |
+
# Correct the values of updated input image
|
| 184 |
+
with torch.no_grad():
|
| 185 |
+
input_img.clamp_(0, 1)
|
| 186 |
+
|
| 187 |
+
optimizer.zero_grad()
|
| 188 |
+
model(input_img)
|
| 189 |
+
style_score = 0
|
| 190 |
+
content_score = 0
|
| 191 |
+
|
| 192 |
+
for sl in style_losses:
|
| 193 |
+
# Apply per-layer weight
|
| 194 |
+
style_score += sl.loss * sl.weight
|
| 195 |
+
|
| 196 |
+
for cl in content_losses:
|
| 197 |
+
content_score += cl.loss
|
| 198 |
+
|
| 199 |
+
style_score *= style_weight
|
| 200 |
+
content_score *= content_weight
|
| 201 |
+
|
| 202 |
+
loss = style_score + content_score
|
| 203 |
+
loss.backward()
|
| 204 |
+
|
| 205 |
+
current_step += 1
|
| 206 |
+
if current_step % 50 == 0:
|
| 207 |
+
elapsed = time.time() - start_time
|
| 208 |
+
print(f"Iteration: {current_step}, Style Loss: {style_score.item():.2f}, Content Loss: {content_score.item():.2f}, Total Loss: {loss.item():.2f}, Time: {elapsed:.1f}s")
|
| 209 |
+
|
| 210 |
+
if progress_callback:
|
| 211 |
+
progress = {
|
| 212 |
+
'iteration': current_step,
|
| 213 |
+
'style_loss': style_score.item(),
|
| 214 |
+
'content_loss': content_score.item(),
|
| 215 |
+
'elapsed_time': elapsed
|
| 216 |
+
}
|
| 217 |
+
progress_callback(progress)
|
| 218 |
+
|
| 219 |
+
# Save best result so far
|
| 220 |
+
nonlocal best_loss, best_img, prev_loss
|
| 221 |
+
current_loss = loss.item()
|
| 222 |
+
|
| 223 |
+
if current_loss < best_loss:
|
| 224 |
+
best_loss = current_loss
|
| 225 |
+
best_img = input_img.clone()
|
| 226 |
+
|
| 227 |
+
# Update previous loss for next iteration
|
| 228 |
+
prev_loss = current_loss
|
| 229 |
+
return loss
|
| 230 |
+
|
| 231 |
+
# Run optimization with early stopping
|
| 232 |
+
while current_step < num_steps:
|
| 233 |
+
optimizer.step(closure)
|
| 234 |
+
|
| 235 |
+
# Check stopping conditions after minimum iterations
|
| 236 |
+
if current_step >= 50 and prev_loss > 1000:
|
| 237 |
+
print(f"Stopping early at iteration {current_step} due to high loss: {prev_loss:.2f}")
|
| 238 |
+
break
|
| 239 |
+
|
| 240 |
+
# A final correction
|
| 241 |
+
with torch.no_grad():
|
| 242 |
+
input_img.clamp_(0, 1)
|
| 243 |
+
|
| 244 |
+
print(f"Total time: {time.time() - start_time:.1f}s")
|
| 245 |
+
print(f"Best loss achieved: {best_loss:.2f}")
|
| 246 |
+
|
| 247 |
+
# Return both the final and best image (often the same)
|
| 248 |
+
return input_img, best_img, best_loss
|
| 249 |
+
|
| 250 |
+
# Save tensor as image
|
| 251 |
+
def save_image(tensor, path):
|
| 252 |
+
image = tensor.cpu().clone()
|
| 253 |
+
image = image.squeeze(0) # Remove batch dimension
|
| 254 |
+
image = transforms.ToPILImage()(image)
|
| 255 |
+
image.save(path)
|
| 256 |
+
return image
|
| 257 |
+
|
| 258 |
+
# Main style transfer function
|
| 259 |
+
def transfer_style(content_path, style_path, output_path, style_weight=1000000,
|
| 260 |
+
content_weight=1, num_steps=300, layer_weights=None,
|
| 261 |
+
progress_callback=None):
|
| 262 |
+
"""
|
| 263 |
+
Perform style transfer and save the result
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
content_path: Path to content image
|
| 267 |
+
style_path: Path to style image
|
| 268 |
+
output_path: Where to save the output image
|
| 269 |
+
style_weight: Weight for style loss
|
| 270 |
+
content_weight: Weight for content loss
|
| 271 |
+
num_steps: Number of optimization steps
|
| 272 |
+
layer_weights: Dictionary of weights for each style layer
|
| 273 |
+
progress_callback: Function to call for progress updates
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
Tuple of (output_path, best_loss)
|
| 277 |
+
"""
|
| 278 |
+
# Load images
|
| 279 |
+
content_img = image_loader(content_path)
|
| 280 |
+
style_img = image_loader(style_path)
|
| 281 |
+
|
| 282 |
+
# Start with content image for faster convergence
|
| 283 |
+
input_img = content_img.clone()
|
| 284 |
+
|
| 285 |
+
# Load VGG19 for feature extraction
|
| 286 |
+
cnn = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.to(device).eval()
|
| 287 |
+
|
| 288 |
+
# Mean and std for normalization (from ImageNet)
|
| 289 |
+
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
|
| 290 |
+
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
|
| 291 |
+
|
| 292 |
+
# Run style transfer
|
| 293 |
+
output, best_output, best_loss = run_style_transfer(
|
| 294 |
+
cnn,
|
| 295 |
+
cnn_normalization_mean,
|
| 296 |
+
cnn_normalization_std,
|
| 297 |
+
content_img,
|
| 298 |
+
style_img,
|
| 299 |
+
input_img,
|
| 300 |
+
num_steps=num_steps,
|
| 301 |
+
style_weight=style_weight,
|
| 302 |
+
content_weight=content_weight,
|
| 303 |
+
layer_weights=layer_weights,
|
| 304 |
+
progress_callback=progress_callback
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
# Save result and return path
|
| 308 |
+
save_image(best_output, output_path)
|
| 309 |
+
|
| 310 |
+
return output_path, best_loss
|