|
|
from fastapi import FastAPI, File, UploadFile, HTTPException, Header, Request, Form, Depends, Body |
|
|
from typing import Optional, Tuple |
|
|
from fastapi.responses import FileResponse |
|
|
from huggingface_hub import hf_hub_download |
|
|
import uuid |
|
|
import os |
|
|
import io |
|
|
import json |
|
|
import logging |
|
|
import httpx |
|
|
from PIL import Image |
|
|
import torch |
|
|
from torchvision import transforms |
|
|
import numpy as np |
|
|
from pydantic import BaseModel, EmailStr |
|
|
from app.database import ( |
|
|
get_database, |
|
|
log_api_call, |
|
|
log_image_upload, |
|
|
log_colorization, |
|
|
log_media_click, |
|
|
close_connection, |
|
|
) |
|
|
try: |
|
|
import firebase_admin |
|
|
from firebase_admin import auth as firebase_auth, app_check, credentials |
|
|
except ImportError: |
|
|
firebase_admin = None |
|
|
firebase_auth = None |
|
|
app_check = None |
|
|
credentials = None |
|
|
|
|
|
|
|
|
try: |
|
|
from app.colorizers import eccv16, siggraph17 |
|
|
from app.colorizers.util import preprocess_img, postprocess_tens |
|
|
CCO_AVAILABLE = True |
|
|
except ImportError as e: |
|
|
print(f"⚠️ CCO colorizers not available: {e}") |
|
|
CCO_AVAILABLE = False |
|
|
|
|
|
# Module-level logger; handlers and levels are not configured in this file.
logger = logging.getLogger(__name__)

# ASGI application; every route in this module is registered on it.
app = FastAPI(title="Text-Guided Image Colorization API")

# Cross-origin configuration.
# NOTE(review): allowing every origin, method and header while also setting
# allow_credentials=True is maximally permissive; confirm this is intended
# before relying on cookie/credential-based auth from browsers.
from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Optional Firebase Admin initialisation. Service-account credentials are read
# as a JSON string from the FIREBASE_CREDENTIALS environment variable. Any
# failure is printed and swallowed so the API still starts without Firebase;
# request handlers detect that state by checking `firebase_admin._apps`.
if firebase_admin:
    try:
        firebase_json = os.getenv("FIREBASE_CREDENTIALS")

        if firebase_json:
            print("🔥 Loading Firebase credentials from ENV...")
            firebase_dict = json.loads(firebase_json)
            cred = credentials.Certificate(firebase_dict)
            firebase_admin.initialize_app(cred)
        else:
            print("⚠️ No Firebase credentials found. Firebase disabled.")

    except Exception as e:
        # Malformed JSON or unusable credentials end up here; Firebase stays off.
        print("❌ Firebase initialization failed:", e)
else:
    # The firebase_admin import at the top of the module failed.
    print("⚠️ Firebase Admin SDK not available. Firebase features disabled.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Local directories for uploaded originals, full-size colorized results and
# compressed JPEG copies. They live under /tmp, so they are presumably
# ephemeral in the deployment environment (confirm before relying on them).
UPLOAD_DIR = "/tmp/uploads"
RESULTS_DIR = "/tmp/results"
COMPRESSED_DIR = "/tmp/compressed"

# Created eagerly at import time so request handlers can write immediately.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(COMPRESSED_DIR, exist_ok=True)

# Passed to log_media_click as `default_category_id`; presumably used when a
# request carries no category. Overridable via DEFAULT_CATEGORY_FALLBACK.
MEDIA_CLICK_DEFAULT_CATEGORY = os.getenv("DEFAULT_CATEGORY_FALLBACK", "69368fcd2e46bd68ae1889b2")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# GAN colorizer weights hosted on the Hugging Face Hub.
MODEL_REPO = "Hammad712/GAN-Colorization-Model"
MODEL_FILENAME = "generator.pt"

# Fetched at import time, so a cold start needs network access to the Hub.
print("⬇️ Downloading GAN model...")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)

# NOTE(review): `state_dict` is not referenced anywhere else in this module and
# colorize_image_gan() is a placeholder that ignores these weights, so this
# load only costs startup time and memory. torch.load also unpickles a file
# downloaded from the network; confirm the installed torch's `weights_only`
# default (or pass weights_only=True) before trusting it.
print("📦 Loading GAN model weights...")
state_dict = torch.load(model_path, map_location="cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pretrained CCO colorizers keyed by name ("eccv16", "siggraph17"), loaded at
# import time in eval mode. On failure CCO_AVAILABLE is flipped to False so
# colorize_image_cco() and the /colorize endpoint stop offering CCO.
cco_models = {}
if CCO_AVAILABLE:
    print("📦 Loading CCO models...")
    try:
        cco_models["eccv16"] = eccv16(pretrained=True).eval()
        cco_models["siggraph17"] = siggraph17(pretrained=True).eval()
        print("✅ CCO models loaded successfully!")
    except Exception as e:
        print(f"⚠️ Failed to load CCO models: {e}")
        # eccv16 may already be in cco_models if only siggraph17 failed, but the
        # flag disables CCO entirely.
        CCO_AVAILABLE = False
|
|
|
|
|
def colorize_image_gan(img: Image.Image):
    """Placeholder "GAN" colorizer: return a grayscale RGB copy of *img*.

    NOTE: this does not run the downloaded GAN generator. It converts the
    input to luminance and replicates that channel three times. Replace the
    body with real model inference once the generator is wired up.

    Args:
        img: Source PIL image in any mode.

    Returns:
        A new RGB PIL image with the same size as *img*.
    """
    transform = transforms.ToTensor()
    # Single-channel luminance tensor with a leading batch axis: (1, 1, H, W).
    tensor = transform(img.convert("L")).unsqueeze(0)
    # Replicate the channel into three: (1, 3, H, W).
    tensor = tensor.repeat(1, 3, 1, 1)
    # Drop only the batch axis. A bare .squeeze() would also remove H or W
    # when either equals 1, yielding a wrongly shaped single-channel image.
    output_img = transforms.ToPILImage()(tensor.squeeze(0))
    return output_img
|
|
|
|
|
def colorize_image_cco(img: Image.Image, model_name: str = "eccv16"):
    """Colorize *img* with one of the pretrained CCO networks.

    Args:
        img: Source PIL image; converted to RGB before preprocessing.
        model_name: "eccv16" or "siggraph17". Any other value silently falls
            back to "eccv16".

    Returns:
        A new RGB PIL image built from the network output via
        postprocess_tens, clipped to [0, 1] and quantised to 8 bits.

    Raises:
        ValueError: if CCO is unavailable or the selected model is not loaded.
    """
    if not CCO_AVAILABLE:
        raise ValueError("CCO models are not available")

    chosen = model_name if model_name in ("eccv16", "siggraph17") else "eccv16"
    network = cco_models.get(chosen)
    if network is None:
        raise ValueError(f"CCO model '{chosen}' not loaded")

    rgb = img if img.mode == "RGB" else img.convert("RGB")
    pixels = np.asarray(rgb)
    # Defensive: guarantee three channels even if a 2-D array slips through.
    if pixels.ndim == 2:
        pixels = np.tile(pixels[:, :, None], 3)

    # Presumably (original-resolution L tensor, network-resolution L tensor);
    # confirm against app.colorizers.util.preprocess_img.
    l_full, l_resized = preprocess_img(pixels)

    with torch.no_grad():
        predicted_ab = network(l_resized)

    as_float = np.clip(postprocess_tens(l_full, predicted_ab), 0, 1)
    return Image.fromarray((as_float * 255).astype(np.uint8), 'RGB')
|
|
|
|
|
def colorize_image(img: Image.Image, model_type: str = "gan", cco_model: str = "eccv16"):
    """Dispatch *img* to the colorizer selected by *model_type*.

    Args:
        img: PIL image to colorize.
        model_type: "cco" selects the CCO networks; every other value,
            including the default "gan", uses the GAN placeholder.
        cco_model: CCO network name ("eccv16" or "siggraph17"); only consulted
            when model_type is "cco".

    Returns:
        The colorized PIL image.
    """
    if model_type != "cco":
        return colorize_image_gan(img)
    return colorize_image_cco(img, cco_model)
|
|
|
|
|
def compress_image_to_target_size(img: Image.Image, target_size_mb: float = 2.5, min_size_mb: float = 2.0, max_size_mb: float = 3.0, max_attempts: int = 20) -> Tuple[Image.Image, str]:
    """Re-encode *img* as JPEG, tuning quality and scale toward a size window.

    Starting at quality 85 and full resolution, each attempt saves a temporary
    JPEG and measures it:

    * too large: lower quality in steps of 5 down to 30, then shrink the
      resolution in 5% steps down to 50% (resetting quality to 75);
    * too small: raise quality (by 5 when under 1 MB, else by 2) up to 95,
      then restore resolution toward 100%. Images are never upscaled, so a
      small source may never reach the window;
    * inside [min_size_mb, max_size_mb]: accept.

    The search also stops when no further adjustment is possible or after
    *max_attempts* encodes. In both cases the last encode is returned.

    Args:
        img: PIL image to compress; converted to RGB if necessary.
        target_size_mb: Nominal target in MB. Kept for API compatibility; the
            search only uses the [min_size_mb, max_size_mb] window.
        min_size_mb: Lower bound of the acceptable size in MB.
        max_size_mb: Upper bound of the acceptable size in MB.
        max_attempts: Maximum number of encodes (at least one is made).

    Returns:
        Tuple of (image that was encoded, path to a temporary JPEG holding
        it). The file exists on return and the caller owns it (move or delete
        it). Files from rejected attempts are removed.
    """
    import tempfile

    def _discard(path: str) -> None:
        # Best-effort cleanup of a rejected attempt; failures are logged only.
        try:
            os.unlink(path)
        except OSError as exc:
            logger.warning("Could not remove temporary file %s: %s", path, exc)

    if img.mode != "RGB":
        img = img.convert("RGB")

    original_width, original_height = img.size
    current_img = img.copy()

    quality = 85
    scale_factor = 1.0
    attempts = max(1, int(max_attempts))

    for attempt in range(attempts):
        # Persistent temp file; close the OS handle so PIL can reopen the path.
        fd, temp_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        try:
            if scale_factor < 1.0:
                # max(1, ...) prevents a zero-sized target for very small inputs.
                new_size = (
                    max(1, int(original_width * scale_factor)),
                    max(1, int(original_height * scale_factor)),
                )
                resized_img = current_img.resize(new_size, Image.Resampling.LANCZOS)
            else:
                resized_img = current_img

            resized_img.save(temp_path, "JPEG", quality=quality, optimize=True)
            file_size_mb = os.path.getsize(temp_path) / (1024 * 1024)
        except BaseException:
            _discard(temp_path)
            raise

        logger.info("Compression attempt %d: quality=%d, scale=%.2f, size=%.2f MB",
                    attempt + 1, quality, scale_factor, file_size_mb)

        done = False
        if min_size_mb <= file_size_mb <= max_size_mb:
            logger.info("Target size achieved: %.2f MB", file_size_mb)
            done = True
        elif file_size_mb > max_size_mb:
            # Too large: give up quality first, then resolution.
            if quality > 30:
                quality -= 5
            elif scale_factor > 0.5:
                # round() stops float drift (0.8999...) from slipping past the 0.5 floor.
                scale_factor = round(scale_factor - 0.05, 2)
                quality = 75
            else:
                logger.warning("Reached minimum compression, file size: %.2f MB", file_size_mb)
                done = True
        else:
            # Too small: raise quality (bigger steps when far below), then resolution.
            if quality < 95:
                quality = min(95, quality + (5 if file_size_mb < 1.0 else 2))
            elif scale_factor < 1.0:
                scale_factor = min(1.0, round(scale_factor + 0.1, 2))
            else:
                logger.info("File size %.2f MB is below minimum but at max quality, accepting", file_size_mb)
                done = True

        if done or attempt == attempts - 1:
            # Keep this file: ownership passes to the caller. (Deleting it here,
            # as a `finally` used to, handed callers a path that no longer existed.)
            return resized_img, temp_path

        _discard(temp_path)

    # Unreachable: the final attempt always returns above.
    raise RuntimeError("compress_image_to_target_size exited without a result")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Open the MongoDB connection when the application starts.

    Errors are printed rather than raised, so startup continues even if the
    database is unreachable.
    """
    try:
        if get_database() is not None:
            print("✅ MongoDB initialized successfully!")
    except Exception as e:
        print(f"⚠️ MongoDB initialization failed: {e}")
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Close the database connection, then note the shutdown on stdout."""
    close_connection()
    print("Application shutdown")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health")
def health_check(request: Request):
    """Liveness probe: return a static status payload and record the call.

    NOTE(review): the payload is hard-coded. "model_loaded" is always True and
    the model_type/provider values do not reflect the locally loaded
    colorizers.
    """
    client_ip = request.client.host if request.client else None
    payload = dict(
        status="healthy",
        model_loaded=True,
        model_type="hf_inference_api",
        provider="fal-ai",
    )
    log_api_call(
        endpoint="/health",
        method="GET",
        status_code=200,
        response_data=payload,
        ip_address=client_ip,
    )
    return payload
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RegisterRequest(BaseModel):
    """Request body for POST /auth/register."""

    # Validated as an email address by pydantic (EmailStr).
    email: EmailStr
    # Forwarded to firebase_auth.create_user; this model applies no strength check.
    password: str
    # Optional display name set on the created Firebase user.
    display_name: Optional[str] = None
|
|
|
|
|
class LoginRequest(BaseModel):
    """Request body for POST /auth/login."""

    # Validated as an email address by pydantic (EmailStr).
    email: EmailStr
    # Plain-text password; checked via Firebase's signInWithPassword REST call
    # when FIREBASE_API_KEY is configured.
    password: str
|
|
|
|
|
class TokenResponse(BaseModel):
    """Response body returned by /auth/register and /auth/login."""

    # Firebase ID token from the REST login flow. NOTE(review): paths built on
    # firebase_auth.create_custom_token (e.g. /auth/register) put a *custom*
    # token here, which is not directly usable as an ID token.
    id_token: str
    # Only populated by the REST login flow.
    refresh_token: Optional[str] = None
    # Token lifetime; presumably seconds (Firebase "expiresIn", default 3600).
    expires_in: int
    token_type: str = "Bearer"
    # Profile summary: uid, email, display_name, email_verified.
    user: dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_bearer_token(authorization_header: Optional[str]) -> Optional[str]: |
|
|
if not authorization_header: |
|
|
return None |
|
|
parts = authorization_header.split(" ", 1) |
|
|
if len(parts) == 2 and parts[0].lower() == "bearer": |
|
|
return parts[1].strip() |
|
|
return None |
|
|
|
|
|
async def verify_request(request: Request):
    """Best-effort Firebase verification, used as a FastAPI dependency.

    Two independent checks run when Firebase is available and initialised:

    1. Firebase App Check token from the ``X-Firebase-AppCheck`` header.
    2. Firebase Auth ID token from ``Authorization: Bearer <token>``. On
       success the decoded claims are stored on ``request.state.user`` for the
       endpoint to use.

    Both checks always run. A valid App Check token must not short-circuit
    ID-token verification, or ``request.state.user`` is never set and callers
    that need it (e.g. /auth/me) reject correctly authenticated users.

    This dependency never rejects a request. Failures are only logged, and it
    returns True even when the SDK is missing or uninitialised. Endpoints that
    require a user must check ``request.state.user`` themselves.

    Returns:
        True, always.
    """
    if not firebase_admin:
        return True

    if not hasattr(firebase_admin, '_apps') or not firebase_admin._apps:
        return True

    # Starlette header lookups are case-insensitive, so one lookup covers any casing.
    app_check_token = request.headers.get("X-Firebase-AppCheck")
    if app_check_token and app_check:
        try:
            app_check_claims = app_check.verify_token(app_check_token)
            logger.info("App Check token verified for: %s", app_check_claims.get("app_id"))
        except Exception as e:
            logger.warning("App Check token verification failed: %s", str(e))

    bearer = _extract_bearer_token(request.headers.get("Authorization"))
    if bearer and firebase_auth:
        try:
            decoded = firebase_auth.verify_id_token(bearer)
            request.state.user = decoded
            logger.info("Firebase Auth id_token verified for uid: %s", decoded.get("uid"))
        except Exception as e:
            logger.warning("Auth token verification failed: %s", str(e))

    return True
|
|
|
|
|
# HTTPException, os and the guarded `app_check` binding come from the imports
# at the top of the module. Do not re-import firebase_admin here: an
# unconditional import fails whenever the SDK is not installed, which defeats
# the ImportError guard and crashes the whole application at import time.


def verify_app_check_token(
    token: str | None,
    *,
    required: bool = False
):
    """Validate a Firebase App Check token, raising 401 when it is unacceptable.

    * required=False: a missing token is allowed; a present but invalid token
      is rejected.
    * required=True: a missing or an invalid token is rejected.

    A token that is present is always verified (fail closed). If verification
    cannot be performed, because the SDK is missing or ``verify_token`` raises
    for any reason (including an uninitialised Firebase app), the token is
    rejected.

    Args:
        token: Raw ``X-Firebase-AppCheck`` header value, or None.
        required: Whether a missing token should also be rejected.

    Returns:
        True when the request may proceed.

    Raises:
        HTTPException: 401 when the token is missing (and required) or invalid.
    """
    if not token:
        if required:
            raise HTTPException(
                status_code=401,
                detail="Firebase App Check token missing"
            )
        return True

    if app_check is None:
        logger.warning("App Check token supplied but firebase_admin is not installed")
        raise HTTPException(
            status_code=401,
            detail="Invalid Firebase App Check token"
        )

    try:
        app_check.verify_token(token)
    except Exception as e:
        logger.warning("App Check token rejected: %s", e)
        raise HTTPException(
            status_code=401,
            detail="Invalid Firebase App Check token"
        ) from e
    return True
|
|
|
|
|
|
|
|
def _resolve_user_id(request: Request, supplied_user_id: Optional[str]) -> Optional[str]:
    """Normalise a client-supplied user id.

    Returns the id with surrounding whitespace removed, or None when it is
    missing or blank (log_media_click is then expected to generate one).
    *request* is currently unused and kept for call-site compatibility.
    """
    cleaned = (supplied_user_id or "").strip()
    return cleaned if cleaned else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/auth/register", response_model=TokenResponse)
async def register_user(user_data: RegisterRequest):
    """Create a Firebase Auth user from an email and password.

    NOTE(review): the response's ``id_token`` field carries a Firebase
    *custom* token from ``create_custom_token``, not an ID token. Clients
    presumably have to exchange it (e.g. signInWithCustomToken) before
    sending it to endpoints that call ``verify_id_token``.

    Raises:
        HTTPException: 503 when Firebase is unavailable or uninitialised; 400
            when the email is already registered or the Admin SDK rejects the
            input (for example a too-short password); 500 otherwise.
    """
    if not firebase_admin:
        raise HTTPException(status_code=503, detail="Firebase Admin SDK not available")

    if not hasattr(firebase_admin, '_apps') or not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")

    if not firebase_auth:
        raise HTTPException(status_code=503, detail="Firebase Auth not available")

    # Looked up defensively: an empty tuple in an `except` clause matches
    # nothing, so SDKs without this class fall through to the generic handler.
    email_exists_error = getattr(firebase_auth, "EmailAlreadyExistsError", ())

    try:
        user_record = firebase_auth.create_user(
            email=user_data.email,
            password=user_data.password,
            display_name=user_data.display_name,
            email_verified=False
        )
        custom_token = firebase_auth.create_custom_token(user_record.uid)
    except email_exists_error as e:
        raise HTTPException(status_code=400, detail="Email already registered") from e
    except ValueError as e:
        # Admin SDK argument validation (malformed email, short password, ...).
        # Previously any message mentioning "email" was misreported as a duplicate.
        raise HTTPException(status_code=400, detail=f"Registration failed: {e}") from e
    except Exception as e:
        error_msg = str(e)
        # Fallback for SDKs that only signal duplicates through the message text.
        if "already exists" in error_msg.lower():
            raise HTTPException(status_code=400, detail="Email already registered") from e
        logger.error("Registration error: %s", error_msg)
        raise HTTPException(status_code=500, detail=f"Registration failed: {error_msg}") from e

    logger.info("User registered: %s (uid: %s)", user_data.email, user_record.uid)

    return TokenResponse(
        id_token=custom_token.decode('utf-8') if isinstance(custom_token, bytes) else custom_token,
        token_type="Bearer",
        expires_in=3600,
        user={
            "uid": user_record.uid,
            "email": user_record.email,
            "display_name": user_record.display_name or "",
            "email_verified": user_record.email_verified
        }
    )
|
|
|
|
|
@app.post("/auth/login", response_model=TokenResponse)
async def login_user(credentials: LoginRequest):
    """Authenticate an email/password pair via the Firebase Auth REST API.

    The password is checked by Identity Toolkit's ``signInWithPassword``
    endpoint, which returns an ID token and a refresh token. That call needs
    the project's web API key (FIREBASE_API_KEY env var or
    ``settings.FIREBASE_API_KEY``).

    Security: without an API key the password cannot be verified, so the
    endpoint fails with 503. It previously looked the user up by email and
    minted a custom token without checking the password, which let anyone log
    in as any registered user.

    Raises:
        HTTPException: 503 when Firebase or the API key is unavailable; 401
            when Firebase rejects the credentials; 500 on transport or
            unexpected errors.
    """
    if not firebase_admin:
        raise HTTPException(status_code=503, detail="Firebase Admin SDK not available")

    if not hasattr(firebase_admin, '_apps') or not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")

    if not firebase_auth:
        raise HTTPException(status_code=503, detail="Firebase Auth not available")

    from app.config import settings
    firebase_api_key = os.getenv("FIREBASE_API_KEY") or getattr(settings, 'FIREBASE_API_KEY', '')

    if not firebase_api_key:
        # Refuse rather than issue a token to whoever knows an existing email.
        logger.error("Login unavailable: FIREBASE_API_KEY is not configured")
        raise HTTPException(status_code=503, detail="Firebase API key not configured")

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={firebase_api_key}",
                json={
                    "email": credentials.email,
                    "password": credentials.password,
                    "returnSecureToken": True
                }
            )

        if response.status_code != 200:
            # Firebase reports errors as {"error": {"message": ...}}; tolerate
            # bodies that are not JSON or not shaped that way.
            try:
                error_msg = response.json().get("error", {}).get("message", "Authentication failed")
            except (ValueError, AttributeError):
                error_msg = "Authentication failed"
            raise HTTPException(status_code=401, detail=error_msg)

        data = response.json()
        logger.info("User login successful: %s", credentials.email)

        # Fetch the profile from the Admin SDK to build the response's user block.
        user_record = firebase_auth.get_user(data["localId"])

        return TokenResponse(
            id_token=data["idToken"],
            refresh_token=data.get("refreshToken"),
            expires_in=int(data.get("expiresIn", 3600)),
            token_type="Bearer",
            user={
                "uid": user_record.uid,
                "email": user_record.email,
                "display_name": user_record.display_name or "",
                "email_verified": user_record.email_verified
            }
        )
    except httpx.HTTPError as e:
        logger.error("HTTP error during login: %s", str(e))
        raise HTTPException(status_code=500, detail="Authentication service unavailable")
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Login error: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Login failed: {str(e)}")
|
|
|
|
|
@app.get("/auth/me")
async def get_current_user(request: Request, verified: bool = Depends(verify_request)):
    """Return the profile of the user behind the request's Bearer ID token.

    Relies on the ``verify_request`` dependency, which stores decoded ID-token
    claims on ``request.state.user`` when verification succeeds.

    Raises:
        HTTPException: 503 when Firebase is unavailable or uninitialised; 401
            when no verified user is attached; 404 when the user record
            cannot be fetched.
    """
    if not firebase_admin:
        raise HTTPException(status_code=503, detail="Firebase Admin SDK not available")
    if not getattr(firebase_admin, '_apps', None):
        raise HTTPException(status_code=503, detail="Firebase not initialized")
    if not firebase_auth:
        raise HTTPException(status_code=503, detail="Firebase Auth not available")

    is_authenticated = hasattr(request, 'state') and hasattr(request.state, 'user')
    if not is_authenticated:
        raise HTTPException(status_code=401, detail="Not authenticated")

    uid = request.state.user.get("uid")
    try:
        record = firebase_auth.get_user(uid)
        profile = {
            "uid": record.uid,
            "email": record.email,
            "display_name": record.display_name or "",
            "email_verified": record.email_verified,
        }
    except Exception as e:
        logger.error("Error getting user: %s", str(e))
        raise HTTPException(status_code=404, detail="User not found")
    return profile
|
|
|
|
|
@app.post("/auth/refresh")
async def refresh_token(refresh_token_param: str = Body(..., embed=True)):
    """Exchange a Firebase refresh token for a fresh ID token.

    Calls the Secure Token REST endpoint with the project's web API key
    (FIREBASE_API_KEY env var or ``settings.FIREBASE_API_KEY``). Because the
    parameter is embedded, the request body is
    ``{"refresh_token_param": "..."}``.

    Raises:
        HTTPException: 503 without an API key; 401 when the refresh is
            rejected; 500 on transport or unexpected errors.
    """
    from app.config import settings

    api_key = os.getenv("FIREBASE_API_KEY") or getattr(settings, 'FIREBASE_API_KEY', '')
    if not api_key:
        raise HTTPException(status_code=503, detail="Firebase API key not configured")

    endpoint = f"https://securetoken.googleapis.com/v1/token?key={api_key}"
    payload = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token_param
    }

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(endpoint, json=payload)

        if response.status_code != 200:
            failure = response.json()
            reason = failure.get("error", {}).get("message", "Token refresh failed")
            raise HTTPException(status_code=401, detail=reason)

        body = response.json()
        return {
            "id_token": body["id_token"],
            "refresh_token": body.get("refresh_token"),
            "expires_in": int(body.get("expires_in", 3600)),
            "token_type": "Bearer"
        }
    except httpx.HTTPError as e:
        logger.error("HTTP error during token refresh: %s", str(e))
        raise HTTPException(status_code=500, detail="Token refresh service unavailable")
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Token refresh error: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Token refresh failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/upload")
async def upload_image(
    request: Request,
    file: UploadFile = File(...),
    x_firebase_appcheck: str = Header(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
):
    """Store an uploaded image under UPLOAD_DIR and return its public URL.

    The upload is written as ``<uuid>.jpg``, served by /uploads/{filename}.
    The call is recorded with log_image_upload, log_api_call and
    log_media_click.

    NOTE(review): the ``.jpg`` name is used regardless of the actual format,
    and the body is read fully into memory with no size limit.

    Raises:
        HTTPException: 401 for an invalid App Check token; 400 when the part's
            content type is missing or not ``image/*``.
    """
    verify_app_check_token(x_firebase_appcheck)

    ip_address = request.client.host if request.client else None
    effective_user_id = _resolve_user_id(request, user_id)
    # Accept snake_case and camelCase field names; the first non-blank value
    # wins (a blank category_id no longer hides a valid categoryId).
    effective_category_id = next(
        (value.strip() for value in (category_id, categoryId) if value and value.strip()),
        None,
    )

    # content_type is None when the multipart part has no Content-Type header;
    # treat that as "not an image" instead of raising AttributeError (500).
    if not (file.content_type or "").startswith("image/"):
        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=400,
            error="Invalid file type",
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=400, detail="Invalid file type")

    image_stem = str(uuid.uuid4())
    image_id = f"{image_stem}.jpg"
    file_path = os.path.join(UPLOAD_DIR, image_id)

    img_bytes = await file.read()
    file_size = len(img_bytes)
    with open(file_path, "wb") as f:
        f.write(img_bytes)

    base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"

    response_data = {
        "success": True,
        "image_id": image_stem,
        "file_url": f"{base_url}/uploads/{image_id}"
    }

    log_image_upload(
        image_id=image_stem,
        filename=file.filename or image_id,
        file_size=file_size,
        content_type=file.content_type or "image/jpeg",
        user_id=effective_user_id,
        ip_address=ip_address
    )

    log_api_call(
        endpoint="/upload",
        method="POST",
        status_code=200,
        request_data={"filename": file.filename, "content_type": file.content_type},
        response_data=response_data,
        user_id=effective_user_id,
        ip_address=ip_address
    )

    log_media_click(
        user_id=effective_user_id,
        category_id=effective_category_id,
        endpoint_path=str(request.url.path),
        default_category_id=MEDIA_CLICK_DEFAULT_CATEGORY,
    )

    return response_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/colorize")
async def colorize(
    request: Request,
    file: UploadFile = File(...),
    x_firebase_appcheck: str = Header(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
    model: Optional[str] = Form(None),
):
    """Colorize an uploaded image and return URLs for the stored results.

    Model selection (form field ``model``, trimmed and case-insensitive):
      * omitted or unrecognised -> CCO eccv16 when available, otherwise GAN
      * "gan"                   -> GAN placeholder colorizer
      * "cco-siggraph17"        -> CCO siggraph17
      * "cco" or other "cco-*"  -> CCO eccv16

    The full-resolution result is saved as PNG under RESULTS_DIR. A JPEG copy
    aimed at 2-3 MB is saved under COMPRESSED_DIR. Outcomes are recorded via
    log_api_call / log_colorization, and successes also via log_media_click.

    Raises:
        HTTPException: 401 for an invalid App Check token; 400 for a non-image
            upload or when CCO is requested but unavailable; 500 when
            colorization or saving fails.
    """
    import shutil
    import time

    start_time = time.time()

    # Optional App Check: a missing header is allowed, an invalid one is 401.
    verify_app_check_token(x_firebase_appcheck, required=False)

    ip_address = request.client.host if request.client else None
    effective_user_id = _resolve_user_id(request, user_id)
    # Accept snake_case and camelCase field names; the first non-blank value wins.
    effective_category_id = next(
        (value.strip() for value in (category_id, categoryId) if value and value.strip()),
        None,
    )

    # Defaults. cco_model is bound on every path: it used to be assigned only
    # when CCO was available, so the GAN fallback hit UnboundLocalError inside
    # colorize_image() and always returned 500.
    cco_model = "eccv16"
    if CCO_AVAILABLE:
        model_type = "cco"
        model_type_for_log = "cco-eccv16"
    else:
        model_type = "gan"
        model_type_for_log = "gan"

    if model:
        model = model.strip().lower()
        if model == "gan":
            model_type = "gan"
            model_type_for_log = "gan"
        elif model == "cco" or model.startswith("cco-"):
            if not CCO_AVAILABLE:
                error_msg = "CCO models are not available"
                log_api_call(
                    endpoint="/colorize",
                    method="POST",
                    status_code=400,
                    error=error_msg,
                    user_id=effective_user_id,
                    ip_address=ip_address
                )
                log_colorization(
                    result_id=None,
                    model_type="cco",
                    processing_time=None,
                    user_id=effective_user_id,
                    ip_address=ip_address,
                    status="failed",
                    error=error_msg
                )
                raise HTTPException(status_code=400, detail=error_msg)

            model_type = "cco"
            cco_model = "siggraph17" if model == "cco-siggraph17" else "eccv16"
            model_type_for_log = f"cco-{cco_model}"
        # Any other value is ignored and the defaults above apply.

    # content_type is None when the part has no Content-Type header.
    if not (file.content_type or "").startswith("image/"):
        error_msg = "Invalid file type"
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=400,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address
        )
        log_colorization(
            result_id=None,
            model_type=model_type_for_log,
            processing_time=None,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="failed",
            error=error_msg
        )
        raise HTTPException(status_code=400, detail=error_msg)

    try:
        img = Image.open(io.BytesIO(await file.read()))
        if img.mode != "RGB":
            img = img.convert("RGB")

        output_img = colorize_image(img, model_type=model_type, cco_model=cco_model)

        # Normalise defensively so both the PNG and the JPEG encodes get RGB.
        if output_img.mode != "RGB":
            output_img = output_img.convert("RGB")

        processing_time = time.time() - start_time

        # Full-resolution, lossless result.
        result_id = f"{uuid.uuid4()}.png"
        output_path = os.path.join(RESULTS_DIR, result_id)
        output_img.save(output_path, "PNG")

        logger.info("Creating compressed version targeting 2-3MB file size...")
        compressed_img, temp_compressed_path = compress_image_to_target_size(
            output_img,
            target_size_mb=2.5,
            min_size_mb=2.0,
            max_size_mb=3.0
        )

        compressed_filename = result_id.replace(".png", "_compressed.jpg")
        compressed_path = os.path.join(COMPRESSED_DIR, compressed_filename)

        # Prefer the compressor's own file; re-encode only if it is missing.
        if os.path.exists(temp_compressed_path):
            shutil.move(temp_compressed_path, compressed_path)
        else:
            compressed_img.save(compressed_path, "JPEG", quality=75, optimize=True)

        compressed_size_mb = os.path.getsize(compressed_path) / (1024 * 1024)
        original_size_mb = os.path.getsize(output_path) / (1024 * 1024)
        logger.info("Original image size: %.2f MB, Compressed image size: %.2f MB",
                    original_size_mb, compressed_size_mb)

        base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
        result_id_clean = result_id.replace(".png", "")

        # Static caption; it is not derived from the image.
        caption = "colorize this image with vibrant natural colors, high quality"

        response_data = {
            "success": True,
            "result_id": result_id_clean,
            "download_url": f"{base_url}/results/{result_id}",
            "api_download_url": f"{base_url}/download/{result_id_clean}",
            "filename": result_id,
            "caption": caption,
            "Compressed_Image_URL": f"{base_url}/compressed/{compressed_filename}"
        }

        log_colorization(
            result_id=result_id_clean,
            model_type=model_type_for_log,
            processing_time=processing_time,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="success"
        )

        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=200,
            request_data={"filename": file.filename, "content_type": file.content_type, "model": model},
            response_data=response_data,
            user_id=effective_user_id,
            ip_address=ip_address
        )

        log_media_click(
            user_id=effective_user_id,
            category_id=effective_category_id,
            endpoint_path=str(request.url.path),
            default_category_id=MEDIA_CLICK_DEFAULT_CATEGORY,
        )

        return response_data
    except Exception as e:
        error_msg = str(e)
        logger.error("Error colorizing image: %s", error_msg)

        log_colorization(
            result_id=None,
            model_type=model_type_for_log,
            processing_time=None,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="failed",
            error=error_msg
        )

        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {error_msg}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/download/{file_id}")
def download_result(
    request: Request,
    file_id: str,
    x_firebase_appcheck: str = Header(None)
):
    """Serve a colorized result by id, preferring ``<id>.png`` over ``<id>.jpg``.

    Raises:
        HTTPException: 401 for an invalid App Check token; 404 when neither
            file exists in RESULTS_DIR.
    """
    verify_app_check_token(x_firebase_appcheck)

    client_ip = request.client.host if request.client else None

    candidates = (
        (os.path.join(RESULTS_DIR, f"{file_id}.png"), "image/png"),
        (os.path.join(RESULTS_DIR, f"{file_id}.jpg"), "image/jpeg"),
    )
    match = next(((p, mime) for p, mime in candidates if os.path.exists(p)), None)

    if match is None:
        log_api_call(
            endpoint=f"/download/{file_id}",
            method="GET",
            status_code=404,
            error="Result not found",
            ip_address=client_ip
        )
        raise HTTPException(status_code=404, detail="Result not found")

    found_path, found_media_type = match
    log_api_call(
        endpoint=f"/download/{file_id}",
        method="GET",
        status_code=200,
        request_data={"file_id": file_id},
        ip_address=client_ip
    )
    return FileResponse(found_path, media_type=found_media_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/results/{filename}")
def get_result(request: Request, filename: str):
    """Serve a file from RESULTS_DIR by name.

    The media type is inferred from the extension: .png gives image/png,
    .jpg/.jpeg gives image/jpeg, and anything else defaults to image/png.

    Raises:
        HTTPException: 404 unless *filename* names a regular file in RESULTS_DIR.
    """
    ip_address = request.client.host if request.client else None

    path = os.path.join(RESULTS_DIR, filename)
    # isfile, not exists: names such as "." or ".." resolve to directories,
    # which FileResponse cannot send (a 500); report them as 404 instead.
    if not os.path.isfile(path):
        log_api_call(
            endpoint=f"/results/{filename}",
            method="GET",
            status_code=404,
            error="Result not found",
            ip_address=ip_address
        )
        raise HTTPException(status_code=404, detail="Result not found")

    lowered = filename.lower()
    if lowered.endswith('.png'):
        media_type = "image/png"
    elif lowered.endswith(('.jpg', '.jpeg')):
        media_type = "image/jpeg"
    else:
        media_type = "image/png"

    log_api_call(
        endpoint=f"/results/{filename}",
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address
    )

    return FileResponse(path, media_type=media_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/uploads/{filename}")
def get_upload(request: Request, filename: str):
    """Serve a previously uploaded file from UPLOAD_DIR by name.

    The media type is inferred from the extension: .png gives image/png, and
    .jpg/.jpeg or anything else gives image/jpeg.

    Raises:
        HTTPException: 404 unless *filename* names a regular file in UPLOAD_DIR.
    """
    ip_address = request.client.host if request.client else None

    path = os.path.join(UPLOAD_DIR, filename)
    # isfile, not exists: "." / ".." resolve to directories, which FileResponse
    # cannot send (a 500); report them as 404 instead.
    if not os.path.isfile(path):
        log_api_call(
            endpoint=f"/uploads/{filename}",
            method="GET",
            status_code=404,
            error="File not found",
            ip_address=ip_address
        )
        raise HTTPException(status_code=404, detail="File not found")

    lowered = filename.lower()
    if lowered.endswith('.png'):
        media_type = "image/png"
    elif lowered.endswith(('.jpg', '.jpeg')):
        media_type = "image/jpeg"
    else:
        media_type = "image/jpeg"

    log_api_call(
        endpoint=f"/uploads/{filename}",
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address
    )

    return FileResponse(path, media_type=media_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/compressed/{filename}")
def get_compressed(request: Request, filename: str):
    """Serve a compressed JPEG copy from COMPRESSED_DIR by name.

    Always served as image/jpeg, since /colorize writes only JPEGs there.

    Raises:
        HTTPException: 404 unless *filename* names a regular file in
            COMPRESSED_DIR.
    """
    ip_address = request.client.host if request.client else None

    path = os.path.join(COMPRESSED_DIR, filename)
    # isfile, not exists: "." / ".." resolve to directories, which FileResponse
    # cannot send (a 500); report them as 404 instead.
    if not os.path.isfile(path):
        log_api_call(
            endpoint=f"/compressed/{filename}",
            method="GET",
            status_code=404,
            error="Compressed file not found",
            ip_address=ip_address
        )
        raise HTTPException(status_code=404, detail="Compressed file not found")

    media_type = "image/jpeg"

    log_api_call(
        endpoint=f"/compressed/{filename}",
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address
    )

    return FileResponse(path, media_type=media_type)
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the service: API version, product name and publisher."""
    service_info = {
        "version": "1.0.0",
        "Product Name": "Beauty Camera - GlowCam AI Studio",
        "Released By": "LogicGo Infotech",
    }
    return {
        "success": True,
        "message": "Image Colorization API",
        "data": service_info,
    }