|
|
import os |
|
|
import io |
|
|
import json |
|
|
import traceback |
|
|
from datetime import datetime,timedelta |
|
|
from typing import Optional |
|
|
import time |
|
|
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends |
|
|
from fastapi.responses import StreamingResponse, JSONResponse |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel |
|
|
from pymongo import MongoClient |
|
|
import gridfs |
|
|
from bson.objectid import ObjectId |
|
|
from PIL import Image |
|
|
from fastapi.concurrency import run_in_threadpool |
|
|
import shutil |
|
|
import firebase_admin |
|
|
from firebase_admin import credentials, auth |
|
|
from PIL import Image |
|
|
from huggingface_hub import InferenceClient |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Firebase Admin SDK initialisation. The service-account JSON is read from the
# 'firebase_config' environment variable; the app refuses to start without it.
firebase_config_json = os.getenv("firebase_config")
if not firebase_config_json:
    raise RuntimeError("β Missing Firebase config in environment variable 'firebase_config'")

try:
    firebase_creds_dict = json.loads(firebase_config_json)
    cred = credentials.Certificate(firebase_creds_dict)
    firebase_admin.initialize_app(cred)
except Exception as e:
    # Chain the cause so the original traceback (bad JSON, invalid key, ...)
    # is kept in the startup error output.
    raise RuntimeError(f"Failed to initialize Firebase Admin SDK: {e}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Inference API credentials -- mandatory; fail fast at import time.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN in (None, ""):
    raise RuntimeError("HF_TOKEN not set in environment variables")

# Shared client used by the /generate endpoint for image-to-image inference.
hf_client = InferenceClient(token=HF_TOKEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Primary MongoDB connection: GridFS stores input/output images and the
# "logs" collection records every generation attempt.
MONGODB_URI = os.getenv("MONGODB_URI")
if not MONGODB_URI:
    # Fail fast like the Firebase/HF checks above instead of letting
    # MongoClient(None) silently fall back to a localhost server.
    raise RuntimeError("MONGODB_URI not set in environment variables")

DB_NAME = "polaroid_db"

mongo = MongoClient(MONGODB_URI)
db = mongo[DB_NAME]
fs = gridfs.GridFS(db)
logs_collection = db["logs"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI(title="Qwen Image Edit API with Firebase Auth")

# Fully open CORS policy: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; auth here is a Bearer header, so confirm credentials are needed.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def verify_firebase_token(request: Request):
    """FastAPI dependency that authenticates a request with a Firebase ID token.

    Expects an ``Authorization: Bearer <id_token>`` header. On success the
    decoded token claims are stored on ``request.state.user`` and returned.

    Raises:
        HTTPException(401): the header is missing/malformed, or the token
            fails Firebase verification.
    """
    auth_header = request.headers.get("Authorization") or ""
    scheme, _, id_token = auth_header.partition(" ")
    id_token = id_token.strip()

    # Auth scheme names are case-insensitive; an empty token is rejected here
    # rather than being sent to Firebase.
    if scheme.lower() != "bearer" or not id_token:
        raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")

    try:
        # verify_id_token is a synchronous call (presumably it may fetch
        # Google's signing keys over the network -- see Firebase docs), so run
        # it in a worker thread instead of blocking the event loop.
        decoded_token = await run_in_threadpool(auth.verify_id_token, id_token)
    except Exception as e:
        raise HTTPException(status_code=401, detail=f"Invalid or expired Firebase token: {e}") from e

    request.state.user = decoded_token
    return decoded_token
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Response body of the public ``GET /health`` endpoint."""

    status: str  # "ok" when the health check succeeds
    db: str  # name of the MongoDB database in use
    model: str  # model id used for image editing
|
|
|
|
|
|
|
|
def resize_image_if_needed(img: Image.Image, max_size=(1024, 1024)) -> Image.Image:
    """
    Shrink ``img`` in place so it fits within ``max_size`` (width, height),
    preserving aspect ratio. Images already within bounds are left untouched.

    Returns the same image object, for convenient chaining.
    """
    if img.width > max_size[0] or img.height > max_size[1]:
        # Image.ANTIALIAS was removed in Pillow 10 (AttributeError);
        # LANCZOS is the same resampling filter under its current name.
        img.thumbnail(max_size, Image.LANCZOS)
    return img
|
|
|
|
|
def prepare_image(file_bytes: bytes) -> Image.Image:
    """
    Decode uploaded image bytes into an RGB PIL image ready for inference.

    The image is rejected if either side is below 200 px, and is downscaled
    (aspect ratio preserved) to fit within 1024x1024 if it is larger.

    Args:
        file_bytes: raw bytes of the uploaded file.

    Returns:
        An RGB ``PIL.Image.Image`` no larger than 1024x1024.

    Raises:
        HTTPException(400): the image is smaller than 200x200.
        Pillow errors (e.g. for undecodable data) propagate unchanged.
    """
    # Decode exactly once; decoding the same bytes again after validation
    # would double the CPU and memory cost of every upload.
    img = Image.open(io.BytesIO(file_bytes)).convert("RGB")

    if img.width < 200 or img.height < 200:
        raise HTTPException(
            status_code=400,
            detail="Image size is below 200x200 pixels. Please upload a larger image."
        )

    return resize_image_if_needed(img, max_size=(1024, 1024))
|
|
|
|
|
# Size target (bytes) for the JPEG encoder below: 2 MiB.
MAX_COMPRESSED_SIZE = 2 * 1024 * 1024


def compress_pil_image_to_2mb(
    pil_img: Image.Image,
    max_dim: int = 1280
) -> bytes:
    """
    Downscale ``pil_img`` to fit within ``max_dim`` x ``max_dim`` and JPEG-encode it.

    JPEG quality steps down from 85 to 40 in increments of 5 until the encoded
    size is at most MAX_COMPRESSED_SIZE. This is best effort: if even quality
    40 is too large, the quality-40 encoding is returned anyway.

    The caller's image is not modified; resizing happens on the RGB copy
    produced by ``convert``.

    Returns:
        JPEG-encoded bytes (optimized, progressive).
    """
    rgb = pil_img.convert("RGB")
    rgb.thumbnail((max_dim, max_dim), Image.LANCZOS)

    encoded = b""
    for q in range(85, 35, -5):  # 85, 80, ..., 40
        out = io.BytesIO()
        rgb.save(out, format="JPEG", quality=q, optimize=True, progressive=True)
        encoded = out.getvalue()
        if len(encoded) <= MAX_COMPRESSED_SIZE:
            break
    return encoded
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health", response_model=HealthResponse)
def health():
    """Public health check.

    Pings MongoDB; if the ping fails, the exception propagates. Otherwise it
    reports the database name and the image-edit model id.
    """
    mongo.admin.command("ping")
    payload = {"status": "ok", "db": db.name, "model": "Qwen/Qwen-Image-Edit"}
    return HealthResponse(**payload)
|
|
|
|
|
def _ingest_upload(data: bytes, upload: UploadFile):
    """Validate one uploaded image and store its raw bytes in GridFS.

    Returns ``(pil_image, gridfs_id)``. Blocking (PIL decode + GridFS write),
    so async callers should dispatch it with ``run_in_threadpool``.
    """
    pil_img = prepare_image(data)
    file_id = fs.put(
        data,
        filename=upload.filename,
        contentType=upload.content_type,
        metadata={"role": "input"},
    )
    return pil_img, file_id


def _combine_side_by_side(left: Image.Image, right: Optional[Image.Image]) -> Image.Image:
    """Place ``right`` to the right of ``left`` on a new RGB canvas.

    Returns ``left`` unchanged when ``right`` is None. The canvas is as tall
    as the taller image; both images are pasted top-aligned.
    """
    if right is None:
        return left
    canvas = Image.new("RGB", (left.width + right.width, max(left.height, right.height)))
    canvas.paste(left, (0, 0))
    canvas.paste(right, (left.width, 0))
    return canvas


def _record_category_usage(user_id: str, category_id: str) -> None:
    """Best-effort analytics: record one AI edit for ``user_id`` in the admin DB.

    Updates the user's ``media_clicks`` document:
    - ``ai_edit_complete``: lifetime counter, incremented by 1;
    - ``ai_edit_daily_count``: per-day counts, with zero-count days back-filled
      between the last recorded day and today, trimmed to the newest 32 entries;
    - ``categories``: per-category ``click_count`` / ``lastClickedAt``.

    Never raises: any failure (bad ObjectId, unknown category, database error)
    is printed with a ``CATEGORY_LOG_ERROR`` prefix so the edit request itself
    still proceeds. Blocking; async callers should use ``run_in_threadpool``.
    """
    try:
        user_oid = ObjectId(user_id)
        category_oid = ObjectId(category_id)

        # Context manager closes the client's connection pool after use
        # instead of leaking one client per request.
        with MongoClient(os.getenv("ADMIN_MONGODB_URI")) as admin_client:
            admin_db = admin_client["adminPanel"]
            categories_col = admin_db.categories
            media_clicks_col = admin_db.media_clicks

            if not categories_col.find_one({"_id": category_oid}):
                raise ValueError(f"Invalid category_id: {category_id}")

            now = datetime.utcnow()
            today = datetime(now.year, now.month, now.day)  # midnight key

            # Ensure the per-user document exists and bump the lifetime counter.
            media_clicks_col.update_one(
                {"userId": user_oid},
                {
                    "$setOnInsert": {"createdAt": now, "ai_edit_daily_count": []},
                    "$set": {"ai_edit_last_date": now, "updatedAt": now},
                    "$inc": {"ai_edit_complete": 1},
                },
                upsert=True,
            )

            doc = media_clicks_col.find_one({"userId": user_oid}, {"ai_edit_daily_count": 1})
            daily_entries = doc.get("ai_edit_daily_count", []) if doc else []

            # Collapse stored entries into {midnight datetime: count}.
            daily_map = {}
            for entry in daily_entries:
                day = entry["date"]
                if isinstance(day, datetime):
                    day = datetime(day.year, day.month, day.day)
                daily_map[day] = entry["count"]

            # Back-fill zero-count days between the last recorded day and today.
            gap_day = (max(daily_map) if daily_map else today) + timedelta(days=1)
            while gap_day < today:
                daily_map.setdefault(gap_day, 0)
                gap_day += timedelta(days=1)

            # Increment (not overwrite) today's bucket so that several edits on
            # the same day are all counted.
            daily_map[today] = daily_map.get(today, 0) + 1

            final_daily_entries = [
                {"date": day, "count": daily_map[day]} for day in sorted(daily_map)
            ][-32:]

            media_clicks_col.update_one(
                {"userId": user_oid},
                {
                    "$set": {
                        "ai_edit_daily_count": final_daily_entries,
                        "ai_edit_last_date": now,
                        "updatedAt": now,
                    }
                },
            )

            # Bump the category's counter if it is already in the user's list ...
            update_res = media_clicks_col.update_one(
                {"userId": user_oid, "categories.categoryId": category_oid},
                {
                    "$set": {"updatedAt": now, "categories.$.lastClickedAt": now},
                    "$inc": {"categories.$.click_count": 1},
                },
            )
            # ... otherwise append a fresh entry for it.
            if update_res.matched_count == 0:
                media_clicks_col.update_one(
                    {"userId": user_oid},
                    {
                        "$set": {"updatedAt": now},
                        "$push": {
                            "categories": {
                                "categoryId": category_oid,
                                "click_count": 1,
                                "lastClickedAt": now,
                            }
                        },
                    },
                    upsert=True,
                )
    except Exception as e:
        print("CATEGORY_LOG_ERROR:", e)


def _store_output_images(pil_output, prompt, input1_id, input2_id, user_email):
    """Persist the edited image to GridFS as a full PNG plus a compressed JPEG.

    Returns ``(png_id, jpeg_id)``. Blocking; async callers should use
    ``run_in_threadpool``.
    """
    png_buf = io.BytesIO()
    pil_output.save(png_buf, format="PNG")

    out_id = fs.put(
        png_buf.getvalue(),
        filename=f"result_{input1_id}.png",
        contentType="image/png",
        metadata={
            "role": "output",
            "prompt": prompt,
            "input1_id": str(input1_id),
            "input2_id": str(input2_id) if input2_id else None,
            "user_email": user_email,
        },
    )

    compressed_id = fs.put(
        compress_pil_image_to_2mb(pil_output, max_dim=1280),
        filename=f"result_{input1_id}_compressed.jpg",
        contentType="image/jpeg",
        metadata={
            "role": "output_compressed",
            "original_output_id": str(out_id),
            "prompt": prompt,
            "user_email": user_email,
        },
    )
    return out_id, compressed_id


@app.post("/generate")
async def generate(
    prompt: str = Form(...),
    image1: UploadFile = File(...),
    image2: Optional[UploadFile] = File(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    user=Depends(verify_firebase_token)
):
    """Edit one image (or two placed side by side) with Qwen-Image-Edit.

    Requires a Firebase bearer token. Input and output images are stored in
    GridFS and every inference attempt is written to the ``logs`` collection.
    When both ``user_id`` and ``category_id`` are given, usage is recorded
    (best-effort) in the admin database.

    Returns JSON with the full-size output id, the caller's email, the total
    response time and a public URL for the compressed JPEG copy.

    Raises:
        HTTPException(400): an input image could not be read or is too small.
        HTTPException(500): the inference call failed.
    """
    start_time = time.time()
    user_email = user.get("email")

    # PIL, GridFS, pymongo and the HF client are all synchronous; they are
    # dispatched through run_in_threadpool so one request doesn't stall the
    # event loop for every other request.
    try:
        img1_bytes = await image1.read()
        pil_img1, input1_id = await run_in_threadpool(_ingest_upload, img1_bytes, image1)
    except HTTPException:
        raise  # already client-facing (e.g. "image too small"); don't re-wrap
    except Exception as e:
        raise HTTPException(400, f"Failed to read first image: {e}") from e

    pil_img2 = None
    input2_id = None
    if image2:
        try:
            img2_bytes = await image2.read()
            pil_img2, input2_id = await run_in_threadpool(_ingest_upload, img2_bytes, image2)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(400, f"Failed to read second image: {e}") from e

    combined_img = _combine_side_by_side(pil_img1, pil_img2)

    if user_id and category_id:
        await run_in_threadpool(_record_category_usage, user_id, category_id)

    try:
        pil_output = await run_in_threadpool(
            hf_client.image_to_image,
            image=combined_img,
            prompt=prompt,
            model="Qwen/Qwen-Image-Edit",
        )
    except Exception as e:
        response_time_ms = round((time.time() - start_time) * 1000)
        await run_in_threadpool(logs_collection.insert_one, {
            "timestamp": datetime.utcnow(),
            "status": "failure",
            "input1_id": str(input1_id),
            "input2_id": str(input2_id) if input2_id else None,
            "prompt": prompt,
            "user_email": user_email,
            "error": str(e),
            "response_time_ms": response_time_ms,
        })
        raise HTTPException(500, f"Inference failed: {e}") from e

    out_id, compressed_id = await run_in_threadpool(
        _store_output_images, pil_output, prompt, input1_id, input2_id, user_email
    )

    response_time_ms = round((time.time() - start_time) * 1000)

    await run_in_threadpool(logs_collection.insert_one, {
        "timestamp": datetime.utcnow(),
        "status": "success",
        "input1_id": str(input1_id),
        "input2_id": str(input2_id) if input2_id else None,
        "output_id": str(out_id),
        "prompt": prompt,
        "user_email": user_email,
        "response_time_ms": response_time_ms,
    })

    return JSONResponse({
        "output_image_id": str(out_id),
        "user": user_email,
        "response_time_ms": response_time_ms,
        "Compressed_Image_URL": (
            f"https://logicgoinfotechspaces-polaroidimage.hf.space/image/{compressed_id}"
        ),
    })
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/image/{image_id}")
def get_image(image_id: str, download: Optional[bool] = False):
    """Retrieve stored image by ID (no authentication required).

    Args:
        image_id: hex ObjectId string of a GridFS file.
        download: when true, adds a ``Content-Disposition: attachment`` header.

    Raises:
        HTTPException(404): the id is malformed or no such file exists.
    """
    from urllib.parse import quote

    try:
        grid_out = fs.get(ObjectId(image_id))
    except Exception:
        raise HTTPException(status_code=404, detail="Image not found")

    chunk_size = 256 * 1024

    def iterfile():
        # Stream in bounded chunks rather than one read() of the whole file,
        # and release the GridOut once the body has been sent.
        try:
            while True:
                chunk = grid_out.read(chunk_size)
                if not chunk:
                    break
                yield chunk
        finally:
            grid_out.close()

    headers = {}
    if download:
        # The stored filename comes from the client's upload: it may be None,
        # contain quotes, or be non-ASCII (which would fail latin-1 header
        # encoding). Send a sanitized ASCII fallback plus the RFC 6266
        # UTF-8 ``filename*`` form.
        filename = grid_out.filename or image_id
        fallback = "".join(
            ch if 32 <= ord(ch) < 127 and ch not in '"\\' else "_" for ch in filename
        )
        headers["Content-Disposition"] = (
            f"attachment; filename=\"{fallback}\"; "
            f"filename*=UTF-8''{quote(filename, safe='')}"
        )

    return StreamingResponse(
        iterfile(),
        media_type=grid_out.content_type or "application/octet-stream",
        headers=headers
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Local entry point. The "app:app" import string presumably assumes this
    # module is saved as app.py -- TODO confirm. reload=True is a dev setting.
    import uvicorn

    uvicorn.run(app="app:app", host="0.0.0.0", port=7860, reload=True)
|
|
|
|
|
|