Adisri99 commited on
Commit
c4a0359
·
verified ·
1 Parent(s): d5b9020

Upload 12 files

Browse files
Files changed (12) hide show
  1. Dockerfile +26 -0
  2. README.md +45 -5
  3. app.py +9 -0
  4. auth.py +97 -0
  5. auth_storage.py +201 -0
  6. email_service.py +121 -0
  7. env.example +30 -0
  8. main.py +722 -0
  9. requirements.txt +14 -0
  10. run.py +10 -0
  11. storage_hf.py +146 -0
  12. style_transfer.py +310 -0
Dockerfile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ ENV PYTHONDONTWRITEBYTECODE=1 \
4
+ PYTHONUNBUFFERED=1 \
5
+ PIP_NO_CACHE_DIR=1 \
6
+ PORT=7860
7
+
8
+ WORKDIR /app
9
+
10
+ # Install git first (needed for Hugging Face Spaces build process)
11
+ RUN apt-get update && apt-get install -y --no-install-recommends \
12
+ git \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ # System deps (optional, pillow usually fine without extra)
16
+ RUN apt-get update && apt-get install -y --no-install-recommends \
17
+ build-essential \
18
+ && rm -rf /var/lib/apt/lists/*
19
+
20
+ COPY requirements.txt ./
21
+ RUN pip install --upgrade pip && pip install -r requirements.txt
22
+
23
+ COPY . .
24
+
25
+ CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${PORT:-7860} --proxy-headers --forwarded-allow-ips=*"]
26
+
README.md CHANGED
@@ -1,10 +1,50 @@
1
  ---
2
- title: StyleExplorer
3
- emoji: 🏆
4
- colorFrom: pink
5
- colorTo: green
6
  sdk: docker
 
7
  pinned: false
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Neural Style Transfer API
3
+ emoji: 🎨
4
+ colorFrom: purple
5
+ colorTo: pink
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
+ license: mit
10
  ---
11
 
12
+ # Neural Style Transfer API
13
+
14
+ FastAPI backend for neural style transfer using PyTorch and VGG19.
15
+
16
+ ## Features
17
+
18
+ - Style transfer processing with customizable parameters
19
+ - Gallery management for generated images
20
+ - User authentication and permission requests
21
+ - Image storage via Hugging Face datasets
22
+
23
+ ## Environment Variables
24
+
25
+ Set these in the Hugging Face Space secrets:
26
+
27
+ - `HF_DATASET_REPO`: Your Hugging Face dataset repository (e.g., `username/style-transfer-data`)
28
+ - `HF_TOKEN`: Hugging Face access token with write permissions
29
+ - `MASTER_PASSWORD`: Master password for admin access
30
+ - `ADMIN_EMAIL`: Admin email for receiving permission request notifications
31
+ - `ALLOWED_ORIGINS`: Comma-separated list of allowed CORS origins
32
+ - `SMTP_HOST`: SMTP server host (optional, for email notifications)
33
+ - `SMTP_PORT`: SMTP server port (optional)
34
+ - `SMTP_USER`: SMTP username (optional)
35
+ - `SMTP_PASSWORD`: SMTP password (optional)
36
+ - `SMTP_FROM_EMAIL`: Email address to send from (optional)
37
+
38
+ ## API Endpoints
39
+
40
+ - `GET /api/health` - Health check
41
+ - `POST /api/transfer` - Create style transfer job (requires auth)
42
+ - `GET /api/transfer/{job_id}` - Get job status
43
+ - `GET /api/gallery` - List gallery items
44
+ - `GET /api/gallery/{item_id}` - Get gallery item
45
+ - `DELETE /api/gallery/{item_id}` - Delete gallery item (requires auth)
46
+ - `POST /api/auth/login` - Login
47
+ - `POST /api/auth/requests` - Submit permission request
48
+ - `GET /api/auth/requests` - List requests (admin only)
49
+ - `POST /api/auth/requests/{id}/approve` - Approve request (admin only)
50
+ - `POST /api/auth/requests/{id}/reject` - Reject request (admin only)
app.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Entry point for Hugging Face Spaces deployment.
3
+ This file simply imports and exposes the FastAPI app from main.py.
4
+ """
5
+ from main import app
6
+
7
+ # The app is already configured in main.py
8
+ # Hugging Face Spaces will automatically detect and serve this FastAPI app
9
+ #filler
auth.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from jose import jwt
3
+ from jose.exceptions import ExpiredSignatureError, JWTError
4
+ from datetime import datetime, timedelta
5
+ from typing import Optional, Dict, Any
6
+ from fastapi import HTTPException, status, Depends
7
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
8
+ from passlib.context import CryptContext
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # JWT configuration
14
+ SECRET_KEY = os.getenv("JWT_SECRET_KEY", os.getenv("MASTER_PASSWORD", "change-me-in-production"))
15
+ ALGORITHM = "HS256"
16
+ ACCESS_TOKEN_EXPIRE_HOURS = 24 * 7 # 7 days
17
+
18
+ # Password hashing
19
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
20
+
21
+ # HTTP Bearer token scheme
22
+ security = HTTPBearer()
23
+
24
+
25
+ def verify_password(plain_password: str, hashed_password: str) -> bool:
26
+ """Verify a password against a hash."""
27
+ return pwd_context.verify(plain_password, hashed_password)
28
+
29
+
30
+ def get_password_hash(password: str) -> str:
31
+ """Hash a password."""
32
+ return pwd_context.hash(password)
33
+
34
+
35
+ def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
36
+ """Create a JWT access token."""
37
+ to_encode = data.copy()
38
+ if expires_delta:
39
+ expire = datetime.utcnow() + expires_delta
40
+ else:
41
+ expire = datetime.utcnow() + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
42
+ to_encode.update({"exp": expire})
43
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
44
+ return encoded_jwt
45
+
46
+
47
+ def verify_token(token: str) -> Optional[Dict[str, Any]]:
48
+ """Verify and decode a JWT token."""
49
+ try:
50
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
51
+ return payload
52
+ except ExpiredSignatureError:
53
+ logger.warning("Token has expired")
54
+ return None
55
+ except JWTError:
56
+ logger.warning("Invalid token")
57
+ return None
58
+
59
+
60
+ async def get_current_user(
61
+ credentials: HTTPAuthorizationCredentials = Depends(security)
62
+ ) -> Dict[str, Any]:
63
+ """
64
+ Dependency to get the current authenticated user from JWT token.
65
+ """
66
+ token = credentials.credentials
67
+ payload = verify_token(token)
68
+ if payload is None:
69
+ raise HTTPException(
70
+ status_code=status.HTTP_401_UNAUTHORIZED,
71
+ detail="Invalid authentication credentials",
72
+ headers={"WWW-Authenticate": "Bearer"},
73
+ )
74
+ return payload
75
+
76
+
77
+ async def get_current_user_optional(
78
+ credentials: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False))
79
+ ) -> Optional[Dict[str, Any]]:
80
+ """
81
+ Dependency to get the current user if authenticated, None otherwise.
82
+ """
83
+ if credentials is None:
84
+ return None
85
+ token = credentials.credentials
86
+ payload = verify_token(token)
87
+ return payload
88
+
89
+
90
+ def require_auth(func):
91
+ """
92
+ Decorator to require authentication for an endpoint.
93
+ """
94
+ async def wrapper(*args, **kwargs):
95
+ # This will be handled by the Depends(get_current_user) in the route
96
+ return await func(*args, **kwargs)
97
+ return wrapper
auth_storage.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import logging
4
+ from typing import Any, Dict, List, Optional
5
+ from datetime import datetime
6
+ import uuid
7
+
8
+ from huggingface_hub import HfApi, CommitOperationAdd, create_commit, hf_hub_url
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ USERS_FILE_PATH = "auth/users.json"
13
+ REQUESTS_FILE_PATH = "auth/requests.json"
14
+
15
+
16
+ def build_dataset_resolve_url(repo_id: str, path_in_repo: str, revision: str = "main") -> str:
17
+ """
18
+ Build a CDN-resolved URL for a file stored in a Hugging Face dataset repo.
19
+ """
20
+ return hf_hub_url(repo_id=repo_id, filename=path_in_repo, repo_type="dataset", revision=revision)
21
+
22
+
23
+ class AuthStorageClient:
24
+ """
25
+ Helper for managing user authentication data and permission requests
26
+ in a Hugging Face dataset repository.
27
+
28
+ Repo format:
29
+ - auth/users.json
30
+ - auth/requests.json
31
+ """
32
+
33
+ def __init__(self, dataset_repo: str, hf_token: Optional[str] = None, revision: str = "main"):
34
+ if not dataset_repo:
35
+ raise ValueError("HF_DATASET_REPO is not set. Please configure the dataset repository id.")
36
+ self.dataset_repo = dataset_repo
37
+ self.revision = revision
38
+ self.api = HfApi(token=hf_token) if hf_token else HfApi()
39
+
40
+ def load_users(self) -> List[Dict[str, Any]]:
41
+ """
42
+ Download and parse users.json from the dataset. If missing, return [].
43
+ """
44
+ try:
45
+ url = build_dataset_resolve_url(self.dataset_repo, USERS_FILE_PATH, self.revision)
46
+ import requests
47
+ headers = {}
48
+ if self.api.token:
49
+ headers["Authorization"] = f"Bearer {self.api.token}"
50
+ resp = requests.get(url, timeout=10, headers=headers)
51
+ if resp.status_code == 200:
52
+ data = resp.json()
53
+ return data.get("users", [])
54
+ logger.info("Users file not found at %s (status %s). Initializing empty users list.", url, resp.status_code)
55
+ return []
56
+ except Exception as e:
57
+ logger.error("Failed to load users from HF: %s", str(e))
58
+ return []
59
+
60
+ def save_users(self, users: List[Dict[str, Any]]) -> None:
61
+ """
62
+ Commit a new version of users.json to the dataset repo.
63
+ """
64
+ try:
65
+ payload = json.dumps({"users": users}, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
66
+ operations = [
67
+ CommitOperationAdd(path_in_repo=USERS_FILE_PATH, path_or_fileobj=payload)
68
+ ]
69
+ create_commit(
70
+ repo_id=self.dataset_repo,
71
+ repo_type="dataset",
72
+ operations=operations,
73
+ commit_message="Update users.json",
74
+ revision=self.revision,
75
+ token=self.api.token,
76
+ )
77
+ except Exception as e:
78
+ logger.error("Failed to save users to HF: %s", str(e))
79
+ raise
80
+
81
+ def load_requests(self) -> List[Dict[str, Any]]:
82
+ """
83
+ Download and parse requests.json from the dataset. If missing, return [].
84
+ """
85
+ try:
86
+ url = build_dataset_resolve_url(self.dataset_repo, REQUESTS_FILE_PATH, self.revision)
87
+ import requests
88
+ headers = {}
89
+ if self.api.token:
90
+ headers["Authorization"] = f"Bearer {self.api.token}"
91
+ resp = requests.get(url, timeout=10, headers=headers)
92
+ if resp.status_code == 200:
93
+ data = resp.json()
94
+ return data.get("requests", [])
95
+ logger.info("Requests file not found at %s (status %s). Initializing empty requests list.", url, resp.status_code)
96
+ return []
97
+ except Exception as e:
98
+ logger.error("Failed to load requests from HF: %s", str(e))
99
+ return []
100
+
101
+ def save_requests(self, requests: List[Dict[str, Any]]) -> None:
102
+ """
103
+ Commit a new version of requests.json to the dataset repo.
104
+ """
105
+ try:
106
+ payload = json.dumps({"requests": requests}, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
107
+ operations = [
108
+ CommitOperationAdd(path_in_repo=REQUESTS_FILE_PATH, path_or_fileobj=payload)
109
+ ]
110
+ create_commit(
111
+ repo_id=self.dataset_repo,
112
+ repo_type="dataset",
113
+ operations=operations,
114
+ commit_message="Update requests.json",
115
+ revision=self.revision,
116
+ token=self.api.token,
117
+ )
118
+ except Exception as e:
119
+ logger.error("Failed to save requests to HF: %s", str(e))
120
+ raise
121
+
122
+ def add_user(self, email: str, password_hash: str) -> None:
123
+ """
124
+ Add a new user to the users list.
125
+ """
126
+ users = self.load_users()
127
+ # Check if user already exists
128
+ if any(user.get("email") == email for user in users):
129
+ raise ValueError(f"User with email {email} already exists")
130
+
131
+ users.append({
132
+ "email": email,
133
+ "password_hash": password_hash
134
+ })
135
+ self.save_users(users)
136
+
137
+ def get_user(self, email: str) -> Optional[Dict[str, Any]]:
138
+ """
139
+ Get a user by email.
140
+ """
141
+ users = self.load_users()
142
+ return next((user for user in users if user.get("email") == email), None)
143
+
144
+ def delete_user(self, email: str) -> None:
145
+ """
146
+ Delete a user by email.
147
+ """
148
+ users = self.load_users()
149
+ users = [user for user in users if user.get("email") != email]
150
+ self.save_users(users)
151
+
152
+ def add_request(self, name: str, email: str, reason: str) -> str:
153
+ """
154
+ Add a new permission request. Returns the request ID.
155
+ """
156
+ requests = self.load_requests()
157
+ request_id = str(uuid.uuid4())
158
+
159
+ new_request = {
160
+ "id": request_id,
161
+ "name": name,
162
+ "email": email,
163
+ "reason": reason,
164
+ "timestamp": datetime.utcnow().isoformat(),
165
+ "status": "pending",
166
+ "reviewed_at": None,
167
+ "rejection_reason": None
168
+ }
169
+
170
+ requests.append(new_request)
171
+ self.save_requests(requests)
172
+ return request_id
173
+
174
+ def get_request(self, request_id: str) -> Optional[Dict[str, Any]]:
175
+ """
176
+ Get a request by ID.
177
+ """
178
+ requests = self.load_requests()
179
+ return next((req for req in requests if req.get("id") == request_id), None)
180
+
181
+ def update_request_status(self, request_id: str, status: str, rejection_reason: Optional[str] = None) -> None:
182
+ """
183
+ Update the status of a request.
184
+ """
185
+ requests = self.load_requests()
186
+ for req in requests:
187
+ if req.get("id") == request_id:
188
+ req["status"] = status
189
+ req["reviewed_at"] = datetime.utcnow().isoformat()
190
+ if rejection_reason:
191
+ req["rejection_reason"] = rejection_reason
192
+ break
193
+ self.save_requests(requests)
194
+
195
+ def delete_request(self, request_id: str) -> None:
196
+ """
197
+ Delete a request by ID.
198
+ """
199
+ requests = self.load_requests()
200
+ requests = [req for req in requests if req.get("id") != request_id]
201
+ self.save_requests(requests)
email_service.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import smtplib
3
+ import logging
4
+ from email.mime.text import MIMEText
5
+ from email.mime.multipart import MIMEMultipart
6
+ from typing import Optional
7
+ from datetime import datetime
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class EmailService:
13
+ """
14
+ Service for sending emails via SMTP.
15
+ Supports Gmail and other SMTP servers.
16
+ """
17
+
18
+ def __init__(self):
19
+ self.smtp_host = os.getenv("SMTP_HOST", "smtp.gmail.com")
20
+ self.smtp_port = int(os.getenv("SMTP_PORT", "587"))
21
+ self.smtp_user = os.getenv("SMTP_USER", "")
22
+ self.smtp_password = os.getenv("SMTP_PASSWORD", "")
23
+ self.from_email = os.getenv("SMTP_FROM_EMAIL", self.smtp_user)
24
+ self.admin_email = os.getenv("ADMIN_EMAIL", "")
25
+
26
+ def _send_email(self, to_email: str, subject: str, body: str, is_html: bool = False) -> bool:
27
+ """
28
+ Send an email via SMTP.
29
+ """
30
+ if not self.smtp_user or not self.smtp_password:
31
+ logger.warning("SMTP credentials not configured. Email not sent.")
32
+ return False
33
+
34
+ try:
35
+ msg = MIMEMultipart()
36
+ msg['From'] = self.from_email
37
+ msg['To'] = to_email
38
+ msg['Subject'] = subject
39
+
40
+ if is_html:
41
+ msg.attach(MIMEText(body, 'html'))
42
+ else:
43
+ msg.attach(MIMEText(body, 'plain'))
44
+
45
+ with smtplib.SMTP(self.smtp_host, self.smtp_port) as server:
46
+ server.starttls()
47
+ server.login(self.smtp_user, self.smtp_password)
48
+ server.send_message(msg)
49
+
50
+ logger.info(f"Email sent successfully to {to_email}")
51
+ return True
52
+ except Exception as e:
53
+ logger.error(f"Failed to send email to {to_email}: {str(e)}")
54
+ return False
55
+
56
+ def send_permission_request_notification(self, name: str, email: str, reason: str, timestamp: str) -> bool:
57
+ """
58
+ Send email notification to admin when a permission request is submitted.
59
+ """
60
+ if not self.admin_email:
61
+ logger.warning("ADMIN_EMAIL not configured. Notification not sent.")
62
+ return False
63
+
64
+ subject = f"New Permission Request: {name}"
65
+ body = f"""
66
+ A new permission request has been submitted:
67
+
68
+ Name: {name}
69
+ Email: {email}
70
+ Reason: {reason}
71
+ Timestamp: {timestamp}
72
+
73
+ Please review the request in the admin interface.
74
+ """
75
+ return self._send_email(self.admin_email, subject, body)
76
+
77
+ def send_approval_email(self, user_email: str, user_name: str, password: str) -> bool:
78
+ """
79
+ Send approval email to user with their account credentials.
80
+ """
81
+ subject = "Your Style Transfer Account Has Been Approved"
82
+ body = f"""
83
+ Hello {user_name},
84
+
85
+ Your permission request has been approved! Your account has been created.
86
+
87
+ Login Credentials:
88
+ Email: {user_email}
89
+ Password: {password}
90
+
91
+ You can now access the Neural Style Transfer application and create style transfers.
92
+
93
+ Please keep your password secure and do not share it with anyone.
94
+
95
+ Best regards,
96
+ Style Transfer Team
97
+ """
98
+ return self._send_email(user_email, subject, body)
99
+
100
+ def send_rejection_email(self, user_email: str, user_name: str, reason: Optional[str] = None) -> bool:
101
+ """
102
+ Send rejection email to user.
103
+ """
104
+ subject = "Permission Request Status Update"
105
+ body = f"""
106
+ Hello {user_name},
107
+
108
+ Thank you for your interest in the Neural Style Transfer application.
109
+
110
+ Unfortunately, your permission request has been declined at this time.
111
+ """
112
+ if reason:
113
+ body += f"\nReason: {reason}\n"
114
+
115
+ body += """
116
+ If you have any questions, please feel free to reach out.
117
+
118
+ Best regards,
119
+ Style Transfer Team
120
+ """
121
+ return self._send_email(user_email, subject, body)
env.example ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copy this file to `.env` in the same directory and fill in your values.
2
+
3
+ # Hugging Face Dataset repository to store uploads/results/gallery
4
+ # Format: <username>/<dataset-name>
5
+ HF_DATASET_REPO=your-username/style-transfer-data
6
+
7
+ # Hugging Face access token with write permission to the dataset
8
+ # Create at: https://huggingface.co/settings/tokens
9
+ HF_TOKEN=hf_your_secret_token
10
+
11
+ # Comma-separated list of allowed origins for CORS
12
+ # Include your local dev, GitHub Pages, and deployed frontend URLs
13
+ ALLOWED_ORIGINS=http://localhost:4200,https://your-username.github.io
14
+
15
+ # Master password for admin access (required)
16
+ MASTER_PASSWORD=your-master-password-here
17
+
18
+ # Admin email for receiving permission request notifications (required)
19
+ ADMIN_EMAIL=admin@example.com
20
+
21
+ # Email service configuration (for sending notifications)
22
+ # SMTP settings (for Gmail or other SMTP servers)
23
+ SMTP_HOST=smtp.gmail.com
24
+ SMTP_PORT=587
25
+ SMTP_USER=your-email@gmail.com
26
+ SMTP_PASSWORD=your-app-password
27
+ SMTP_FROM_EMAIL=your-email@gmail.com
28
+
29
+ # Optional: JWT secret key (defaults to MASTER_PASSWORD if not set)
30
+ # JWT_SECRET_KEY=your-jwt-secret-key
main.py ADDED
@@ -0,0 +1,722 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import json
4
+ import secrets
5
+ from fastapi import FastAPI, UploadFile, File, Form, BackgroundTasks, Depends, HTTPException, status
6
+ from fastapi.responses import JSONResponse
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ import logging
9
+ from pydantic import BaseModel
10
+ from typing import Dict, Optional, Any
11
+ import asyncio
12
+ from datetime import datetime
13
+ import time
14
+ import math
15
+ import tempfile
16
+
17
+ from style_transfer import transfer_style
18
+ from storage_hf import (
19
+ HFStorageClient,
20
+ build_dataset_resolve_url
21
+ )
22
+ from auth import (
23
+ verify_password,
24
+ get_password_hash,
25
+ create_access_token,
26
+ verify_token,
27
+ get_current_user
28
+ )
29
+ from auth_storage import AuthStorageClient
30
+ from email_service import EmailService
31
+
32
+ # Set up logging
33
+ logging.basicConfig(level=logging.INFO)
34
+ logger = logging.getLogger(__name__)
35
+
36
+ # Custom JSON encoder to handle infinity
37
+ class CustomJSONEncoder(json.JSONEncoder):
38
+ def default(self, obj):
39
+ if isinstance(obj, float):
40
+ if math.isinf(obj):
41
+ return "Infinity" if obj > 0 else "-Infinity"
42
+ if math.isnan(obj):
43
+ return "NaN"
44
+ return super().default(obj)
45
+
46
+ # Custom JSONResponse that uses our custom encoder
47
+ class CustomJSONResponse(JSONResponse):
48
+ def render(self, content: Any) -> bytes:
49
+ return json.dumps(
50
+ content,
51
+ ensure_ascii=False,
52
+ allow_nan=False,
53
+ indent=None,
54
+ separators=(",", ":"),
55
+ cls=CustomJSONEncoder,
56
+ ).encode("utf-8")
57
+
58
+ # Initialize FastAPI
59
+ app = FastAPI(title="Neural Style Transfer API", default_response_class=CustomJSONResponse)
60
+
61
+ # Add request logging middleware
62
+ @app.middleware("http")
63
+ async def log_requests(request, call_next):
64
+ logger.info(f"Incoming request: {request.method} {request.url.path} from {request.client.host if request.client else 'unknown'}")
65
+ response = await call_next(request)
66
+ logger.info(f"Response: {request.method} {request.url.path} -> {response.status_code}")
67
+ return response
68
+
69
+ # HF storage configuration
70
+ HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "") # e.g. username/style-transfer-data
71
+ HF_TOKEN = os.getenv("HF_TOKEN", "")
72
+ storage_client = HFStorageClient(dataset_repo=HF_DATASET_REPO, hf_token=HF_TOKEN)
73
+
74
+ # Auth storage and email service
75
+ auth_storage = AuthStorageClient(dataset_repo=HF_DATASET_REPO, hf_token=HF_TOKEN)
76
+ email_service = EmailService()
77
+ MASTER_PASSWORD = os.getenv("MASTER_PASSWORD", "")
78
+
79
+ # Setup CORS
80
+ allowed_origins = os.getenv("ALLOWED_ORIGINS", "http://localhost:4200").split(",")
81
+ allowed_origins_clean = [origin.strip() for origin in allowed_origins if origin.strip()]
82
+ logger.info(f"CORS allowed origins: {allowed_origins_clean}")
83
+ app.add_middleware(
84
+ CORSMiddleware,
85
+ allow_origins=allowed_origins_clean,
86
+ allow_credentials=True,
87
+ allow_methods=["*"],
88
+ allow_headers=["*"],
89
+ )
90
+
91
+ """
92
+ With Hugging Face storage, images are served via absolute URLs from the
93
+ dataset CDN, so we no longer mount local static directories.
94
+ """
95
+
96
+ # Keep track of running jobs using asyncio.Queue for thread-safe updates
97
+ job_queues = {}
98
+ active_jobs = {}
99
+
100
+ class StyleTransferProgress(BaseModel):
101
+ job_id: str
102
+ status: str
103
+ progress: Optional[int] = 0
104
+ style_loss: Optional[float] = None
105
+ content_loss: Optional[float] = None
106
+ result_url: Optional[str] = None
107
+ error: Optional[str] = None
108
+
109
+ # Auth models
110
+ class LoginRequest(BaseModel):
111
+ email: Optional[str] = None
112
+ password: Optional[str] = None
113
+ master_password: Optional[str] = None
114
+
115
+ class PermissionRequest(BaseModel):
116
+ name: str
117
+ email: str
118
+ reason: str
119
+
120
+ class CreateUserRequest(BaseModel):
121
+ email: str
122
+ password: str
123
+
124
+ class ApproveRequestData(BaseModel):
125
+ password: Optional[str] = None # Optional custom password, otherwise auto-generated
126
+
127
+ class RejectRequestData(BaseModel):
128
+ reason: Optional[str] = None
129
+
130
+ # Helper function to generate unique file paths
131
+ def get_unique_filename(directory, extension=".jpg"):
132
+ return os.path.join(directory, f"{uuid.uuid4()}{extension}")
133
+
134
+ def load_gallery():
135
+ try:
136
+ return storage_client.load_gallery()
137
+ except Exception as e:
138
+ logger.error(f"Error loading gallery data from HF: {str(e)}")
139
+ return []
140
+
141
+ def save_gallery(gallery_data):
142
+ try:
143
+ storage_client.save_gallery(gallery_data)
144
+ except Exception as e:
145
+ logger.error(f"Error saving gallery data to HF: {str(e)}")
146
+
147
+ # Background task for style transfer
148
+ async def run_style_transfer_task(
149
+ job_id: str,
150
+ content_local_path: str,
151
+ style_local_path: str,
152
+ output_local_path: str,
153
+ style_weight: float,
154
+ content_weight: float,
155
+ num_steps: int,
156
+ layer_weights: Dict[str, float]
157
+ ):
158
+ try:
159
+ # Create a queue for this job if it doesn't exist
160
+ if job_id not in job_queues:
161
+ job_queues[job_id] = asyncio.Queue()
162
+
163
+ queue = job_queues[job_id]
164
+ start_time = time.time()
165
+ best_loss = float('inf')
166
+ style_loss = 0
167
+ content_loss = 0
168
+
169
+ # Update job status
170
+ await queue.put({
171
+ "status": "processing",
172
+ "progress": 0,
173
+ "style_loss": None,
174
+ "content_loss": None
175
+ })
176
+
177
+ # Define a callback that will update the job status
178
+ def on_progress(progress):
179
+ nonlocal style_loss, content_loss, best_loss
180
+ # Calculate total loss as the sum of style and content loss
181
+ total_loss = progress["style_loss"] + progress["content_loss"]
182
+
183
+ # Update the best loss if this one is better
184
+ if total_loss < best_loss:
185
+ best_loss = total_loss
186
+
187
+ progress_data = {
188
+ "status": "processing",
189
+ "progress": progress["iteration"] / num_steps * 100,
190
+ "style_loss": progress["style_loss"],
191
+ "content_loss": progress["content_loss"]
192
+ }
193
+ style_loss = progress["style_loss"]
194
+ content_loss = progress["content_loss"]
195
+
196
+ # Use asyncio.run_coroutine_threadsafe to safely put data in the queue from a different thread
197
+ loop = asyncio.get_event_loop()
198
+ asyncio.run_coroutine_threadsafe(queue.put(progress_data), loop)
199
+
200
+ # Run the style transfer
201
+ result_path, model_best_loss = transfer_style(
202
+ content_path=content_local_path,
203
+ style_path=style_local_path,
204
+ output_path=output_local_path,
205
+ style_weight=style_weight,
206
+ content_weight=content_weight,
207
+ num_steps=num_steps,
208
+ layer_weights=layer_weights,
209
+ progress_callback=on_progress
210
+ )
211
+
212
+ processing_time = time.time() - start_time
213
+
214
+ # If best_loss is still infinity, use the model's best loss or the sum of final losses
215
+ if math.isinf(best_loss):
216
+ best_loss = model_best_loss if not math.isinf(model_best_loss) else style_loss + content_loss
217
+
218
+ # Upload artifacts to Hugging Face dataset and build absolute URLs
219
+ date_prefix = datetime.utcnow().strftime("%Y/%m/%d")
220
+ base_prefix = f"runs/{date_prefix}/{job_id}"
221
+
222
+ content_ds_path = storage_client.upload_file(
223
+ local_path=content_local_path,
224
+ dst_path=f"{base_prefix}/content.jpg"
225
+ )
226
+ style_ds_path = storage_client.upload_file(
227
+ local_path=style_local_path,
228
+ dst_path=f"{base_prefix}/style.jpg"
229
+ )
230
+ result_ds_path = storage_client.upload_file(
231
+ local_path=output_local_path,
232
+ dst_path=f"{base_prefix}/result.jpg"
233
+ )
234
+
235
+ content_url = build_dataset_resolve_url(storage_client.dataset_repo, content_ds_path)
236
+ style_url = build_dataset_resolve_url(storage_client.dataset_repo, style_ds_path)
237
+ result_url = build_dataset_resolve_url(storage_client.dataset_repo, result_ds_path)
238
+
239
+ # Save to gallery
240
+ gallery_item = {
241
+ "id": job_id,
242
+ "timestamp": datetime.utcnow().isoformat(),
243
+ "contentImageUrl": content_url,
244
+ "styleImageUrl": style_url,
245
+ "resultImageUrl": result_url,
246
+ "bestLoss": style_loss + content_loss,
247
+ "styleLoss": style_loss,
248
+ "contentLoss": content_loss,
249
+ "processingTime": processing_time,
250
+ "parameters": {
251
+ "styleWeight": style_weight,
252
+ "contentWeight": content_weight,
253
+ "numSteps": num_steps,
254
+ "layerWeights": layer_weights
255
+ }
256
+ }
257
+
258
+ gallery = load_gallery()
259
+ gallery.append(gallery_item)
260
+ save_gallery(gallery)
261
+
262
+ # Update job status with result
263
+ await queue.put({
264
+ "status": "completed",
265
+ "progress": 100,
266
+ "style_loss": style_loss,
267
+ "content_loss": content_loss,
268
+ "result_url": result_url
269
+ })
270
+
271
+ except Exception as e:
272
+ logger.error(f"Error in style transfer: {str(e)}")
273
+ await queue.put({
274
+ "status": "failed",
275
+ "error": str(e)
276
+ })
277
+ finally:
278
+ # Keep the last status update in active_jobs
279
+ try:
280
+ last_status = queue.get_nowait()
281
+ active_jobs[job_id] = last_status
282
+ except asyncio.QueueEmpty:
283
+ pass
284
+
285
+ @app.post("/api/transfer")
286
+ async def create_style_transfer(
287
+ background_tasks: BackgroundTasks,
288
+ content_image: UploadFile = File(...),
289
+ style_image: UploadFile = File(...),
290
+ style_weight: float = Form(1000000.0),
291
+ content_weight: float = Form(1.0),
292
+ num_steps: int = Form(300),
293
+ layer_weights: str = Form("{}"),
294
+ current_user: Dict[str, Any] = Depends(get_current_user),
295
+ ):
296
+ try:
297
+ # Parse layer weights from JSON string
298
+ layer_weights_dict = json.loads(layer_weights)
299
+
300
+ # Save uploaded files temporarily to local disk for processing
301
+ temp_dir = tempfile.gettempdir()
302
+ content_path = get_unique_filename(temp_dir)
303
+ style_path = get_unique_filename(temp_dir)
304
+ output_path = get_unique_filename(temp_dir)
305
+
306
+ with open(content_path, "wb") as content_file:
307
+ content_file.write(await content_image.read())
308
+
309
+ with open(style_path, "wb") as style_file:
310
+ style_file.write(await style_image.read())
311
+
312
+ # Create a unique job ID
313
+ job_id = str(uuid.uuid4())
314
+
315
+ # Initialize job status
316
+ active_jobs[job_id] = {
317
+ "status": "pending",
318
+ "progress": 0,
319
+ "style_loss": None,
320
+ "content_loss": None,
321
+ "result_url": None,
322
+ "error": None
323
+ }
324
+
325
+ # Start style transfer in the background
326
+ background_tasks.add_task(
327
+ run_style_transfer_task,
328
+ job_id,
329
+ content_path,
330
+ style_path,
331
+ output_path,
332
+ style_weight,
333
+ content_weight,
334
+ num_steps,
335
+ layer_weights_dict
336
+ )
337
+
338
+ return {
339
+ "job_id": job_id,
340
+ "status": "pending"
341
+ }
342
+
343
+ except Exception as e:
344
+ logger.error(f"Error creating style transfer: {str(e)}")
345
+ return CustomJSONResponse(
346
+ status_code=500,
347
+ content={"error": str(e)}
348
+ )
349
+
350
+ @app.get("/api/transfer/{job_id}")
351
+ async def get_transfer_status(job_id: str):
352
+ if job_id not in active_jobs and job_id not in job_queues:
353
+ return CustomJSONResponse(
354
+ status_code=404,
355
+ content={"error": "Job not found"}
356
+ )
357
+
358
+ # Try to get the latest status from the queue
359
+ if job_id in job_queues:
360
+ try:
361
+ # Get the latest status without removing it from the queue
362
+ status = job_queues[job_id].get_nowait()
363
+ job_queues[job_id].put_nowait(status) # Put it back
364
+ active_jobs[job_id] = status # Update active_jobs with latest status
365
+ except asyncio.QueueEmpty:
366
+ # If queue is empty, use the last known status from active_jobs
367
+ pass
368
+
369
+ job_status = active_jobs[job_id]
370
+
371
+ # Return appropriate response based on job status
372
+ if job_status["status"] == "completed" and job_status.get("result_url"):
373
+ # Clean up the queue for completed jobs
374
+ if job_id in job_queues:
375
+ del job_queues[job_id]
376
+ return {
377
+ "job_id": job_id,
378
+ "status": "completed",
379
+ "progress": 100,
380
+ "style_loss": job_status.get("style_loss"),
381
+ "content_loss": job_status.get("content_loss"),
382
+ "result_url": job_status["result_url"]
383
+ }
384
+ elif job_status["status"] == "failed":
385
+ # Clean up the queue for failed jobs
386
+ if job_id in job_queues:
387
+ del job_queues[job_id]
388
+ return {
389
+ "job_id": job_id,
390
+ "status": "failed",
391
+ "error": job_status.get("error", "Unknown error")
392
+ }
393
+ else:
394
+ return {
395
+ "job_id": job_id,
396
+ "status": job_status["status"],
397
+ "progress": job_status.get("progress", 0),
398
+ "style_loss": job_status.get("style_loss"),
399
+ "content_loss": job_status.get("content_loss")
400
+ }
401
+
402
+ @app.get("/api/health")
403
+ async def health_check():
404
+ return {"status": "ok"}
405
+
406
+ @app.get("/api/gallery")
407
+ async def get_gallery_items():
408
+ try:
409
+ gallery = load_gallery()
410
+ # Replace any infinity values before sending response
411
+ for item in gallery:
412
+ if 'bestLoss' in item and isinstance(item['bestLoss'], float) and math.isinf(item['bestLoss']):
413
+ item['bestLoss'] = 999999999 if item['bestLoss'] > 0 else -999999999
414
+ if 'styleLoss' in item and isinstance(item['styleLoss'], float) and math.isinf(item['styleLoss']):
415
+ item['styleLoss'] = 999999999 if item['styleLoss'] > 0 else -999999999
416
+ if 'contentLoss' in item and isinstance(item['contentLoss'], float) and math.isinf(item['contentLoss']):
417
+ item['contentLoss'] = 999999999 if item['contentLoss'] > 0 else -999999999
418
+ return gallery
419
+ except Exception as e:
420
+ logger.error(f"Error getting gallery items: {str(e)}")
421
+ return CustomJSONResponse(
422
+ status_code=500,
423
+ content={"error": str(e)}
424
+ )
425
+
426
+ @app.get("/api/gallery/{item_id}")
427
+ async def get_gallery_item(item_id: str):
428
+ try:
429
+ gallery = load_gallery()
430
+ item = next((item for item in gallery if item["id"] == item_id), None)
431
+ if item is None:
432
+ return CustomJSONResponse(status_code=404, content={"error": "Item not found"})
433
+
434
+ # Replace any infinity values before sending response
435
+ if 'bestLoss' in item and isinstance(item['bestLoss'], float) and math.isinf(item['bestLoss']):
436
+ item['bestLoss'] = 999999999 if item['bestLoss'] > 0 else -999999999
437
+ if 'styleLoss' in item and isinstance(item['styleLoss'], float) and math.isinf(item['styleLoss']):
438
+ item['styleLoss'] = 999999999 if item['styleLoss'] > 0 else -999999999
439
+ if 'contentLoss' in item and isinstance(item['contentLoss'], float) and math.isinf(item['contentLoss']):
440
+ item['contentLoss'] = 999999999 if item['contentLoss'] > 0 else -999999999
441
+
442
+ return item
443
+ except Exception as e:
444
+ logger.error(f"Error getting gallery item: {str(e)}")
445
+ return CustomJSONResponse(
446
+ status_code=500,
447
+ content={"error": str(e)}
448
+ )
449
+
450
+ @app.delete("/api/gallery/{item_id}")
451
+ async def delete_gallery_item(
452
+ item_id: str,
453
+ current_user: Dict[str, Any] = Depends(get_current_user),
454
+ ):
455
+ try:
456
+ gallery = load_gallery()
457
+ item_to_delete = next((item for item in gallery if item["id"] == item_id), None)
458
+
459
+ if not item_to_delete:
460
+ return CustomJSONResponse(
461
+ status_code=404,
462
+ content={"error": "Item not found"}
463
+ )
464
+
465
+ # Remove from gallery first
466
+ gallery = [item for item in gallery if item["id"] != item_id]
467
+ save_gallery(gallery)
468
+
469
+ # Attempt to delete artifacts from dataset
470
+ try:
471
+ storage_client.delete_run_artifacts(item_to_delete)
472
+ except Exception as e:
473
+ logger.error(f"Error deleting dataset artifacts for {item_id}: {str(e)}")
474
+
475
+ return {"status": "success"}
476
+ except Exception as e:
477
+ logger.error(f"Error deleting gallery item: {str(e)}")
478
+ return CustomJSONResponse(
479
+ status_code=500,
480
+ content={"error": str(e)}
481
+ )
482
+
483
+ # Authentication endpoints
484
+ @app.post("/api/auth/login")
485
+ async def login(login_request: LoginRequest):
486
+ """
487
+ Login with email/password or master password.
488
+ """
489
+ logger.info(f"Login attempt - email: {login_request.email}, has_master_password: {bool(login_request.master_password)}")
490
+ try:
491
+ # Check master password first (doesn't require email or password)
492
+ if login_request.master_password and MASTER_PASSWORD and login_request.master_password == MASTER_PASSWORD:
493
+ access_token = create_access_token(data={"email": None, "is_master": True})
494
+ return {"access_token": access_token, "token_type": "bearer", "is_master": True}
495
+
496
+ # Check user email/password (requires both email and password)
497
+ if login_request.email and login_request.password:
498
+ user = auth_storage.get_user(login_request.email)
499
+ if user and verify_password(login_request.password, user["password_hash"]):
500
+ access_token = create_access_token(data={"email": login_request.email, "is_master": False})
501
+ return {"access_token": access_token, "token_type": "bearer", "is_master": False, "email": login_request.email}
502
+
503
+ raise HTTPException(
504
+ status_code=status.HTTP_401_UNAUTHORIZED,
505
+ detail="Incorrect email/password or master password"
506
+ )
507
+ except HTTPException:
508
+ raise
509
+ except Exception as e:
510
+ logger.error(f"Error in login: {str(e)}")
511
+ raise HTTPException(
512
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
513
+ detail="Login failed"
514
+ )
515
+
516
+ @app.post("/api/auth/requests")
517
+ async def submit_permission_request(request: PermissionRequest):
518
+ """
519
+ Submit a permission request. Public endpoint.
520
+ """
521
+ try:
522
+ request_id = auth_storage.add_request(
523
+ name=request.name,
524
+ email=request.email,
525
+ reason=request.reason
526
+ )
527
+
528
+ # Get the request to get timestamp
529
+ req = auth_storage.get_request(request_id)
530
+ if req:
531
+ # Send email notification to admin
532
+ email_service.send_permission_request_notification(
533
+ name=request.name,
534
+ email=request.email,
535
+ reason=request.reason,
536
+ timestamp=req.get("timestamp", "")
537
+ )
538
+
539
+ return {"request_id": request_id, "status": "submitted"}
540
+ except Exception as e:
541
+ logger.error(f"Error submitting permission request: {str(e)}")
542
+ raise HTTPException(
543
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
544
+ detail="Failed to submit request"
545
+ )
546
+
547
+ def verify_master_password(user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
548
+ """Verify that the current user is using master password."""
549
+ if not user.get("is_master"):
550
+ raise HTTPException(
551
+ status_code=status.HTTP_403_FORBIDDEN,
552
+ detail="Master password required"
553
+ )
554
+ return user
555
+
556
+ @app.get("/api/auth/requests")
557
+ async def list_permission_requests(
558
+ admin_user: Dict[str, Any] = Depends(verify_master_password)
559
+ ):
560
+ """List all permission requests. Admin only."""
561
+ try:
562
+ requests = auth_storage.load_requests()
563
+ return {"requests": requests}
564
+ except Exception as e:
565
+ logger.error(f"Error listing requests: {str(e)}")
566
+ raise HTTPException(
567
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
568
+ detail="Failed to list requests"
569
+ )
570
+
571
+ @app.post("/api/auth/requests/{request_id}/approve")
572
+ async def approve_request(
573
+ request_id: str,
574
+ approve_data: ApproveRequestData,
575
+ admin_user: Dict[str, Any] = Depends(verify_master_password)
576
+ ):
577
+ """Approve a permission request and create user account. Admin only."""
578
+ try:
579
+ req = auth_storage.get_request(request_id)
580
+ if not req:
581
+ raise HTTPException(status_code=404, detail="Request not found")
582
+
583
+ if req.get("status") != "pending":
584
+ raise HTTPException(status_code=400, detail="Request already processed")
585
+
586
+ # Generate password if not provided
587
+ password = approve_data.password or secrets.token_urlsafe(12)
588
+ password_hash = get_password_hash(password)
589
+
590
+ # Create user account
591
+ try:
592
+ auth_storage.add_user(email=req["email"], password_hash=password_hash)
593
+ except ValueError as e:
594
+ # User might already exist
595
+ logger.warning(f"User creation warning: {str(e)}")
596
+
597
+ # Update request status
598
+ auth_storage.update_request_status(request_id, "approved")
599
+
600
+ # Send approval email
601
+ email_service.send_approval_email(
602
+ user_email=req["email"],
603
+ user_name=req["name"],
604
+ password=password
605
+ )
606
+
607
+ return {"status": "approved", "email": req["email"]}
608
+ except HTTPException:
609
+ raise
610
+ except Exception as e:
611
+ logger.error(f"Error approving request: {str(e)}")
612
+ raise HTTPException(
613
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
614
+ detail="Failed to approve request"
615
+ )
616
+
617
+ @app.post("/api/auth/requests/{request_id}/reject")
618
+ async def reject_request(
619
+ request_id: str,
620
+ reject_data: RejectRequestData,
621
+ admin_user: Dict[str, Any] = Depends(verify_master_password)
622
+ ):
623
+ """Reject a permission request. Admin only."""
624
+ try:
625
+ req = auth_storage.get_request(request_id)
626
+ if not req:
627
+ raise HTTPException(status_code=404, detail="Request not found")
628
+
629
+ if req.get("status") != "pending":
630
+ raise HTTPException(status_code=400, detail="Request already processed")
631
+
632
+ # Update request status
633
+ auth_storage.update_request_status(request_id, "rejected", reject_data.reason)
634
+
635
+ # Send rejection email
636
+ email_service.send_rejection_email(
637
+ user_email=req["email"],
638
+ user_name=req["name"],
639
+ reason=reject_data.reason
640
+ )
641
+
642
+ return {"status": "rejected"}
643
+ except HTTPException:
644
+ raise
645
+ except Exception as e:
646
+ logger.error(f"Error rejecting request: {str(e)}")
647
+ raise HTTPException(
648
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
649
+ detail="Failed to reject request"
650
+ )
651
+
652
+ @app.delete("/api/auth/requests/{request_id}")
653
+ async def delete_request(
654
+ request_id: str,
655
+ admin_user: Dict[str, Any] = Depends(verify_master_password)
656
+ ):
657
+ """Delete a permission request without sending email. Admin only."""
658
+ try:
659
+ auth_storage.delete_request(request_id)
660
+ return {"status": "deleted"}
661
+ except Exception as e:
662
+ logger.error(f"Error deleting request: {str(e)}")
663
+ raise HTTPException(
664
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
665
+ detail="Failed to delete request"
666
+ )
667
+
668
+ @app.post("/api/auth/users")
669
+ async def create_user(
670
+ user_request: CreateUserRequest,
671
+ admin_user: Dict[str, Any] = Depends(verify_master_password)
672
+ ):
673
+ """Create a new user. Admin only."""
674
+ try:
675
+ password_hash = get_password_hash(user_request.password)
676
+ auth_storage.add_user(email=user_request.email, password_hash=password_hash)
677
+ return {"status": "created", "email": user_request.email}
678
+ except ValueError as e:
679
+ raise HTTPException(status_code=400, detail=str(e))
680
+ except Exception as e:
681
+ logger.error(f"Error creating user: {str(e)}")
682
+ raise HTTPException(
683
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
684
+ detail="Failed to create user"
685
+ )
686
+
687
+ @app.get("/api/auth/users")
688
+ async def list_users(
689
+ admin_user: Dict[str, Any] = Depends(verify_master_password)
690
+ ):
691
+ """List all users. Admin only."""
692
+ try:
693
+ users = auth_storage.load_users()
694
+ # Don't return password hashes
695
+ user_list = [{"email": user["email"]} for user in users]
696
+ return {"users": user_list}
697
+ except Exception as e:
698
+ logger.error(f"Error listing users: {str(e)}")
699
+ raise HTTPException(
700
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
701
+ detail="Failed to list users"
702
+ )
703
+
704
+ @app.delete("/api/auth/users/{email}")
705
+ async def delete_user(
706
+ email: str,
707
+ admin_user: Dict[str, Any] = Depends(verify_master_password)
708
+ ):
709
+ """Delete a user. Admin only."""
710
+ try:
711
+ auth_storage.delete_user(email)
712
+ return {"status": "deleted", "email": email}
713
+ except Exception as e:
714
+ logger.error(f"Error deleting user: {str(e)}")
715
+ raise HTTPException(
716
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
717
+ detail="Failed to delete user"
718
+ )
719
+
720
+ if __name__ == "__main__":
721
+ import uvicorn
722
+ uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ pillow>=10.0.0
4
+ numpy>=1.25.2
5
+ fastapi==0.103.1
6
+ uvicorn==0.23.2
7
+ python-multipart==0.0.6
8
+ aiofiles==23.2.1
9
+ matplotlib>=3.7.2
10
+ requests==2.31.0
11
+ python-dotenv==1.0.0
12
+ huggingface_hub>=0.26.0
13
+ python-jose[cryptography]>=3.3.0
14
+ passlib[bcrypt]>=1.7.4
run.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import uvicorn
2
+ import os
3
+
4
+ # Create necessary directories
5
+ os.makedirs("uploads", exist_ok=True)
6
+ os.makedirs("results", exist_ok=True)
7
+
8
+ if __name__ == "__main__":
9
+ print("Starting Neural Style Transfer API")
10
+ uvicorn.run("main:app", host="0.0.0.0", port=7860)
storage_hf.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import logging
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ from huggingface_hub import HfApi, CommitOperationAdd, CommitOperationDelete, create_commit, hf_hub_url
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ GALLERY_FILE_PATH = "gallery/gallery.json"
12
+
13
+
14
+ def build_dataset_resolve_url(repo_id: str, path_in_repo: str, revision: str = "main") -> str:
15
+ """
16
+ Build a CDN-resolved URL for a file stored in a Hugging Face dataset repo.
17
+ """
18
+ return hf_hub_url(repo_id=repo_id, filename=path_in_repo, repo_type="dataset", revision=revision)
19
+
20
+
21
+ class HFStorageClient:
22
+ """
23
+ Simple helper around huggingface_hub for storing run artifacts and gallery metadata
24
+ in a Dataset repository.
25
+
26
+ Repo format:
27
+ - runs/YYYY/MM/DD/<job_id>/content.jpg
28
+ - runs/YYYY/MM/DD/<job_id>/style.jpg
29
+ - runs/YYYY/MM/DD/<job_id>/result.jpg
30
+ - gallery/gallery.json
31
+ """
32
+
33
+ def __init__(self, dataset_repo: str, hf_token: Optional[str] = None, revision: str = "main"):
34
+ if not dataset_repo:
35
+ raise ValueError("HF_DATASET_REPO is not set. Please configure the dataset repository id.")
36
+ self.dataset_repo = dataset_repo
37
+ self.revision = revision
38
+ self.api = HfApi(token=hf_token) if hf_token else HfApi()
39
+
40
+ def load_gallery(self) -> List[Dict[str, Any]]:
41
+ """
42
+ Download and parse gallery.json from the dataset. If missing, return [].
43
+ """
44
+ try:
45
+ # Try to get the raw file content via the hub URL
46
+ url = build_dataset_resolve_url(self.dataset_repo, GALLERY_FILE_PATH, self.revision)
47
+ import requests # local import to avoid hard dependency elsewhere
48
+ headers = {}
49
+ if self.api.token:
50
+ headers["Authorization"] = f"Bearer {self.api.token}"
51
+ resp = requests.get(url, timeout=10, headers=headers)
52
+ if resp.status_code == 200:
53
+ return resp.json()
54
+ logger.info("Gallery not found at %s (status %s). Initializing empty gallery.", url, resp.status_code)
55
+ return []
56
+ except Exception as e:
57
+ logger.error("Failed to load gallery from HF: %s", str(e))
58
+ return []
59
+
60
+ def save_gallery(self, gallery: List[Dict[str, Any]]) -> None:
61
+ """
62
+ Commit a new version of gallery.json to the dataset repo.
63
+ """
64
+ try:
65
+ payload = json.dumps(gallery, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
66
+ operations = [
67
+ CommitOperationAdd(path_in_repo=GALLERY_FILE_PATH, path_or_fileobj=payload)
68
+ ]
69
+ create_commit(
70
+ repo_id=self.dataset_repo,
71
+ repo_type="dataset",
72
+ operations=operations,
73
+ commit_message="Update gallery.json",
74
+ revision=self.revision,
75
+ token=self.api.token,
76
+ )
77
+ except Exception as e:
78
+ logger.error("Failed to save gallery to HF: %s", str(e))
79
+ raise
80
+
81
+ def upload_file(self, local_path: str, dst_path: str) -> str:
82
+ """
83
+ Upload a local file to the dataset repo at dst_path. Returns the path_in_repo.
84
+ """
85
+ if not os.path.exists(local_path):
86
+ raise FileNotFoundError(local_path)
87
+
88
+ try:
89
+ with open(local_path, "rb") as f:
90
+ operations = [
91
+ CommitOperationAdd(path_in_repo=dst_path, path_or_fileobj=f)
92
+ ]
93
+ create_commit(
94
+ repo_id=self.dataset_repo,
95
+ repo_type="dataset",
96
+ operations=operations,
97
+ commit_message=f"Upload {dst_path}",
98
+ revision=self.revision,
99
+ token=self.api.token,
100
+ )
101
+ return dst_path
102
+ except Exception as e:
103
+ logger.error("Failed to upload %s to HF at %s: %s", local_path, dst_path, str(e))
104
+ raise
105
+
106
+ def delete_run_artifacts(self, gallery_item: Dict[str, Any]) -> None:
107
+ """
108
+ Attempt to delete the three image artifacts associated with a run.
109
+ This parses resolve URLs to determine paths in repo.
110
+ """
111
+ def extract_path(url: Optional[str]) -> Optional[str]:
112
+ if not url:
113
+ return None
114
+ marker = "/resolve/"
115
+ if marker in url:
116
+ try:
117
+ # url ends with .../resolve/<rev>/<path_in_repo>
118
+ parts = url.split(marker, 1)[1].split("/", 1)
119
+ if len(parts) == 2:
120
+ return parts[1]
121
+ except Exception:
122
+ return None
123
+ return None
124
+
125
+ paths: List[str] = []
126
+ for key in ("contentImageUrl", "styleImageUrl", "resultImageUrl"):
127
+ p = extract_path(gallery_item.get(key))
128
+ if p:
129
+ paths.append(p)
130
+
131
+ if not paths:
132
+ return
133
+
134
+ try:
135
+ operations = [CommitOperationDelete(path) for path in paths]
136
+ create_commit(
137
+ repo_id=self.dataset_repo,
138
+ repo_type="dataset",
139
+ operations=operations,
140
+ commit_message=f"Delete artifacts for run {gallery_item.get('id', '')}",
141
+ revision=self.revision,
142
+ token=self.api.token,
143
+ )
144
+ except Exception as e:
145
+ logger.error("Failed to delete artifacts %s: %s", paths, str(e))
146
+
style_transfer.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+ from PIL import Image
6
+ import torchvision.transforms as transforms
7
+ import torchvision.models as models
8
+ import copy
9
+ import time
10
+ import os
11
+ import io
12
+
13
+ # Check if GPU is available
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ print(f"Using device: {device}")
16
+
17
+ # Image loading and preprocessing
18
+ def image_loader(image_path, imsize=512):
19
+ loader = transforms.Compose([
20
+ transforms.Resize(imsize), # Scale imported image
21
+ transforms.CenterCrop(imsize), # Ensure square size
22
+ transforms.ToTensor(), # Transform into torch tensor
23
+ transforms.Lambda(lambda x: x.repeat(1, 1, 1) if x.size(0) == 1 else x) # Convert grayscale to RGB if needed
24
+ ])
25
+
26
+ image = Image.open(image_path).convert('RGB') # Ensure image is RGB
27
+ # Add batch dimension (1, 3, h, w)
28
+ image = loader(image).unsqueeze(0)
29
+ return image.to(device, torch.float)
30
+
31
+ def load_image_from_bytes(image_bytes, imsize=512):
32
+ loader = transforms.Compose([
33
+ transforms.Resize(imsize),
34
+ transforms.CenterCrop(imsize),
35
+ transforms.ToTensor(),
36
+ transforms.Lambda(lambda x: x.repeat(1, 1, 1) if x.size(0) == 1 else x)
37
+ ])
38
+
39
+ image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
40
+ image = loader(image).unsqueeze(0)
41
+ return image.to(device, torch.float)
42
+
43
+ # Content Loss: Measures content similarity
44
+ class ContentLoss(nn.Module):
45
+ def __init__(self, target):
46
+ super(ContentLoss, self).__init__()
47
+ # Detach the target content from the tree used to dynamically compute gradients
48
+ self.target = target.detach()
49
+
50
+ def forward(self, input):
51
+ self.loss = F.mse_loss(input, self.target)
52
+ return input
53
+
54
+ # Gram matrix calculation for style representation
55
+ def gram_matrix(input):
56
+ batch_size, n_channels, height, width = input.size()
57
+ features = input.view(batch_size * n_channels, height * width)
58
+ G = torch.mm(features, features.t())
59
+ # Normalize by total number of elements
60
+ return G.div(batch_size * n_channels * height * width)
61
+
62
+ # Style Loss: Measures style similarity using Gram matrices
63
+ class StyleLoss(nn.Module):
64
+ def __init__(self, target_feature):
65
+ super(StyleLoss, self).__init__()
66
+ self.target = gram_matrix(target_feature).detach()
67
+ self.weight = 1.0 # Default weight for this layer
68
+
69
+ def forward(self, input):
70
+ G = gram_matrix(input)
71
+ self.loss = F.mse_loss(G, self.target)
72
+ return input
73
+
74
+ # Normalization layer for VGG compatibility
75
+ class Normalization(nn.Module):
76
+ def __init__(self, mean, std):
77
+ super(Normalization, self).__init__()
78
+ # View the mean and std as 1x3x1x1 tensors
79
+ self.mean = mean.clone().detach().view(-1, 1, 1).to(device)
80
+ self.std = std.clone().detach().view(-1, 1, 1).to(device)
81
+
82
+ def forward(self, img):
83
+ # Normalize img
84
+ return (img - self.mean) / self.std
85
+
86
+ # Build model with content and style losses
87
+ def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
88
+ style_img, content_img,
89
+ content_layers=['conv_4'],
90
+ style_layers=['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'],
91
+ layer_weights=None):
92
+ normalization = Normalization(normalization_mean, normalization_std)
93
+
94
+ # Set default layer weights if not provided
95
+ if layer_weights is None:
96
+ layer_weights = {layer: 1.0 for layer in style_layers}
97
+
98
+ # Lists to keep track of losses
99
+ content_losses = []
100
+ style_losses = []
101
+
102
+ # Create a "sequential" module with added content/style loss layers
103
+ model = nn.Sequential(normalization)
104
+
105
+ i = 0 # Increment for each conv layer
106
+ for layer in cnn.children():
107
+ if isinstance(layer, nn.Conv2d):
108
+ i += 1
109
+ name = f'conv_{i}'
110
+ elif isinstance(layer, nn.ReLU):
111
+ name = f'relu_{i}'
112
+ # Replace in-place version with out-of-place
113
+ layer = nn.ReLU(inplace=False)
114
+ elif isinstance(layer, nn.MaxPool2d):
115
+ name = f'pool_{i}'
116
+ elif isinstance(layer, nn.BatchNorm2d):
117
+ name = f'bn_{i}'
118
+ else:
119
+ raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')
120
+
121
+ model.add_module(name, layer)
122
+
123
+ # Add content loss
124
+ if name in content_layers:
125
+ # Add content loss:
126
+ target = model(content_img).detach()
127
+ content_loss = ContentLoss(target)
128
+ model.add_module(f"content_loss_{i}", content_loss)
129
+ content_losses.append(content_loss)
130
+
131
+ # Add style loss
132
+ if name in style_layers:
133
+ # Add style loss:
134
+ target_feature = model(style_img).detach()
135
+ style_loss = StyleLoss(target_feature)
136
+
137
+ # Apply customized layer weight
138
+ style_loss.weight = layer_weights.get(name, 1.0)
139
+
140
+ model.add_module(f"style_loss_{i}", style_loss)
141
+ style_losses.append(style_loss)
142
+
143
+ # Trim off the layers after the last content and style losses
144
+ for i in range(len(model) - 1, -1, -1):
145
+ if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
146
+ break
147
+
148
+ model = model[:(i + 1)]
149
+
150
+ return model, style_losses, content_losses
151
+
152
+ # Optimization loop for style transfer
153
+ def run_style_transfer(cnn, normalization_mean, normalization_std,
154
+ content_img, style_img, input_img, num_steps=300,
155
+ style_weight=1000000, content_weight=1,
156
+ layer_weights=None, progress_callback=None):
157
+ """Run the style transfer."""
158
+ num_steps = min(num_steps, 400)
159
+
160
+ print('Building the style transfer model...')
161
+ model, style_losses, content_losses = get_style_model_and_losses(
162
+ cnn, normalization_mean, normalization_std,
163
+ style_img, content_img,
164
+ layer_weights=layer_weights
165
+ )
166
+
167
+ # We want to optimize the input image only
168
+ input_img.requires_grad_(True)
169
+ model.eval() # We don't need gradients for the model parameters
170
+ model.requires_grad_(False)
171
+
172
+ optimizer = optim.LBFGS([input_img])
173
+ best_img = None
174
+ best_loss = float('inf')
175
+ prev_loss = float('inf')
176
+ current_step = 0
177
+
178
+ start_time = time.time()
179
+
180
+ # Function to be used with optimizer
181
+ def closure():
182
+ nonlocal current_step
183
+ # Correct the values of updated input image
184
+ with torch.no_grad():
185
+ input_img.clamp_(0, 1)
186
+
187
+ optimizer.zero_grad()
188
+ model(input_img)
189
+ style_score = 0
190
+ content_score = 0
191
+
192
+ for sl in style_losses:
193
+ # Apply per-layer weight
194
+ style_score += sl.loss * sl.weight
195
+
196
+ for cl in content_losses:
197
+ content_score += cl.loss
198
+
199
+ style_score *= style_weight
200
+ content_score *= content_weight
201
+
202
+ loss = style_score + content_score
203
+ loss.backward()
204
+
205
+ current_step += 1
206
+ if current_step % 50 == 0:
207
+ elapsed = time.time() - start_time
208
+ print(f"Iteration: {current_step}, Style Loss: {style_score.item():.2f}, Content Loss: {content_score.item():.2f}, Total Loss: {loss.item():.2f}, Time: {elapsed:.1f}s")
209
+
210
+ if progress_callback:
211
+ progress = {
212
+ 'iteration': current_step,
213
+ 'style_loss': style_score.item(),
214
+ 'content_loss': content_score.item(),
215
+ 'elapsed_time': elapsed
216
+ }
217
+ progress_callback(progress)
218
+
219
+ # Save best result so far
220
+ nonlocal best_loss, best_img, prev_loss
221
+ current_loss = loss.item()
222
+
223
+ if current_loss < best_loss:
224
+ best_loss = current_loss
225
+ best_img = input_img.clone()
226
+
227
+ # Update previous loss for next iteration
228
+ prev_loss = current_loss
229
+ return loss
230
+
231
+ # Run optimization with early stopping
232
+ while current_step < num_steps:
233
+ optimizer.step(closure)
234
+
235
+ # Check stopping conditions after minimum iterations
236
+ if current_step >= 50 and prev_loss > 1000:
237
+ print(f"Stopping early at iteration {current_step} due to high loss: {prev_loss:.2f}")
238
+ break
239
+
240
+ # A final correction
241
+ with torch.no_grad():
242
+ input_img.clamp_(0, 1)
243
+
244
+ print(f"Total time: {time.time() - start_time:.1f}s")
245
+ print(f"Best loss achieved: {best_loss:.2f}")
246
+
247
+ # Return both the final and best image (often the same)
248
+ return input_img, best_img, best_loss
249
+
250
+ # Save tensor as image
251
+ def save_image(tensor, path):
252
+ image = tensor.cpu().clone()
253
+ image = image.squeeze(0) # Remove batch dimension
254
+ image = transforms.ToPILImage()(image)
255
+ image.save(path)
256
+ return image
257
+
258
+ # Main style transfer function
259
+ def transfer_style(content_path, style_path, output_path, style_weight=1000000,
260
+ content_weight=1, num_steps=300, layer_weights=None,
261
+ progress_callback=None):
262
+ """
263
+ Perform style transfer and save the result
264
+
265
+ Args:
266
+ content_path: Path to content image
267
+ style_path: Path to style image
268
+ output_path: Where to save the output image
269
+ style_weight: Weight for style loss
270
+ content_weight: Weight for content loss
271
+ num_steps: Number of optimization steps
272
+ layer_weights: Dictionary of weights for each style layer
273
+ progress_callback: Function to call for progress updates
274
+
275
+ Returns:
276
+ Tuple of (output_path, best_loss)
277
+ """
278
+ # Load images
279
+ content_img = image_loader(content_path)
280
+ style_img = image_loader(style_path)
281
+
282
+ # Start with content image for faster convergence
283
+ input_img = content_img.clone()
284
+
285
+ # Load VGG19 for feature extraction
286
+ cnn = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.to(device).eval()
287
+
288
+ # Mean and std for normalization (from ImageNet)
289
+ cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
290
+ cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
291
+
292
+ # Run style transfer
293
+ output, best_output, best_loss = run_style_transfer(
294
+ cnn,
295
+ cnn_normalization_mean,
296
+ cnn_normalization_std,
297
+ content_img,
298
+ style_img,
299
+ input_img,
300
+ num_steps=num_steps,
301
+ style_weight=style_weight,
302
+ content_weight=content_weight,
303
+ layer_weights=layer_weights,
304
+ progress_callback=progress_callback
305
+ )
306
+
307
+ # Save result and return path
308
+ save_image(best_output, output_path)
309
+
310
+ return output_path, best_loss