"""Image correction and regeneration endpoints.""" import os import time import uuid import random import logging from datetime import datetime from fastapi import APIRouter, HTTPException, Depends from api.schemas import ( ImageCorrectRequest, ImageCorrectResponse, ImageRegenerateRequest, ImageRegenerateResponse, ImageSelectionRequest, ) from services.correction import correction_service from services.database import db_service from services.image import image_service from services.auth_dependency import get_current_user from config import settings router = APIRouter(tags=["correction"]) api_logger = logging.getLogger("api") @router.post("/api/correct", response_model=ImageCorrectResponse) async def correct_image( request: ImageCorrectRequest, username: str = Depends(get_current_user), ): """ Correct an image by analyzing it for spelling and visual issues, then regenerating a corrected version. Requires authentication. """ api_start_time = time.time() api_logger.info("API: Correction request received | User: %s | Image ID: %s", username, request.image_id) try: image_url = request.image_url ad = None if request.image_id != "temp-id": ad = await db_service.get_ad_creative(request.image_id, username=username) if not ad: raise HTTPException(status_code=404, detail=f"Ad creative with ID {request.image_id} not found or access denied") if not image_url: image_url = ad.get("r2_url") or ad.get("image_url") if not image_url: raise HTTPException(status_code=400, detail="Image URL must be provided for images not in database, or found in database for provided ID") image_bytes = await image_service.load_image( image_id=request.image_id if request.image_id != "temp-id" else None, image_url=image_url, image_bytes=None, filepath=None, ) if not image_bytes: raise HTTPException(status_code=404, detail="Image not found for analysis. Please ensure the URL is accessible.") original_prompt = ad.get("image_prompt") if ad else None result = await correction_service.correct_image( image_bytes=image_bytes, image_url=image_url, original_prompt=original_prompt, width=1024, height=1024, niche=ad.get("niche", "others") if ad else "others", user_instructions=request.user_instructions, auto_analyze=request.auto_analyze, ) response_data = { "status": result["status"], "analysis": result.get("analysis"), "corrections": None, "corrected_image": None, "error": result.get("error"), } if result.get("corrections"): c = result["corrections"] response_data["corrections"] = { "spelling_corrections": c.get("spelling_corrections", []), "visual_corrections": c.get("visual_corrections", []), "corrected_prompt": c.get("corrected_prompt", ""), } if result.get("corrected_image"): ci = result["corrected_image"] response_data["corrected_image"] = { "filename": ci.get("filename"), "filepath": ci.get("filepath"), "image_url": ci.get("image_url"), "r2_url": ci.get("r2_url"), "model_used": ci.get("model_used"), "corrected_prompt": ci.get("corrected_prompt"), } if result.get("status") == "success" and result.get("_db_metadata") and ad: db_metadata = result["_db_metadata"] correction_metadata = { "is_corrected": True, "correction_date": datetime.utcnow().isoformat() + "Z", } for k, v in [ ("original_image_url", ad.get("r2_url") or ad.get("image_url")), ("original_r2_url", ad.get("r2_url")), ("original_image_filename", ad.get("image_filename")), ("original_image_model", ad.get("image_model")), ("original_image_prompt", ad.get("image_prompt")), ]: if v: correction_metadata[k] = v if result.get("corrections"): correction_metadata["corrections"] = result.get("corrections") update_kwargs = { "image_url": db_metadata.get("image_url"), "image_filename": db_metadata.get("filename"), "image_model": db_metadata.get("model_used"), "image_prompt": db_metadata.get("corrected_prompt"), } if db_metadata.get("r2_url"): update_kwargs["r2_url"] = db_metadata.get("r2_url") update_success = await db_service.update_ad_creative( ad_id=request.image_id, username=username, metadata=correction_metadata, **update_kwargs, ) if update_success and response_data.get("corrected_image") is not None: response_data["corrected_image"]["ad_id"] = request.image_id api_logger.info("Correction request completed in %.2fs", time.time() - api_start_time) return response_data except HTTPException: raise except Exception as e: api_logger.exception("Correction failed") raise HTTPException(status_code=500, detail=str(e)) @router.post("/api/regenerate", response_model=ImageRegenerateResponse) async def regenerate_image( request: ImageRegenerateRequest, username: str = Depends(get_current_user), ): """ Regenerate an image for an existing ad creative with an optional new model. If preview_only=True, returns preview without updating DB; use confirm to save. """ api_start_time = time.time() try: ad = await db_service.get_ad_creative(request.image_id, username=username) if not ad: raise HTTPException(status_code=404, detail="Ad creative not found or access denied") image_prompt = ad.get("image_prompt") if not image_prompt: raise HTTPException(status_code=400, detail="No image prompt found for this ad creative.") model_to_use = request.image_model or ad.get("image_model") or settings.image_model seed = random.randint(1, 2147483647) image_bytes, model_used, generated_url = await image_service.generate( prompt=image_prompt, width=1024, height=1024, seed=seed, model_key=model_to_use, ) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") unique_id = uuid.uuid4().hex[:8] niche = ad.get("niche", "unknown").replace(" ", "_") filename = f"regen_{niche}_{timestamp}_{unique_id}.png" r2_url = None try: from services.r2_storage import get_r2_storage r2_storage = get_r2_storage() if r2_storage and image_bytes: r2_url = r2_storage.upload_image(image_bytes=image_bytes, filename=filename, niche=niche) except Exception as e: api_logger.warning("R2 upload failed: %s", e) local_path = None # Always save locally when we have bytes so /images/{filename} works as fallback # (avoids ERR_HTTP2_SERVER_REFUSED_STREAM when R2 presigned URL fails in browser) if image_bytes: local_path = os.path.join(settings.output_dir, filename) os.makedirs(os.path.dirname(local_path), exist_ok=True) with open(local_path, "wb") as f: f.write(image_bytes) original_image_url = ad.get("r2_url") or ad.get("image_url") new_image_url = r2_url or generated_url or f"/images/{filename}" if request.preview_only: return { "status": "success", "regenerated_image": { "filename": filename, "filepath": local_path, "image_url": new_image_url, "r2_url": r2_url, "model_used": model_used, "prompt_used": image_prompt, "seed_used": seed, }, "original_image_url": original_image_url, "original_preserved": True, "is_preview": True, } regeneration_metadata = { "is_regenerated": True, "regeneration_date": datetime.utcnow().isoformat() + "Z", "regeneration_seed": seed, } if original_image_url: regeneration_metadata["original_image_url"] = original_image_url for k, v in [("original_r2_url", ad.get("r2_url")), ("original_image_filename", ad.get("image_filename")), ("original_image_model", ad.get("image_model")), ("original_seed", ad.get("image_seed"))]: if v is not None: regeneration_metadata[k] = v update_kwargs = {"image_filename": filename, "image_model": model_used, "image_seed": seed} if r2_url: update_kwargs["image_url"] = update_kwargs["r2_url"] = r2_url elif generated_url: update_kwargs["image_url"] = generated_url elif local_path: update_kwargs["image_url"] = f"/images/{filename}" await db_service.update_ad_creative( ad_id=request.image_id, username=username, metadata=regeneration_metadata, **update_kwargs, ) return { "status": "success", "regenerated_image": { "filename": filename, "filepath": local_path, "image_url": new_image_url, "r2_url": r2_url, "model_used": model_used, "prompt_used": image_prompt, "seed_used": seed, }, "original_image_url": original_image_url, "original_preserved": True, "is_preview": False, } except HTTPException: raise except Exception as e: api_logger.exception("Regeneration failed") raise HTTPException(status_code=500, detail=str(e)) @router.post("/api/regenerate/confirm") async def confirm_image_selection( request: ImageSelectionRequest, username: str = Depends(get_current_user), ): """ Confirm the user's image selection after regeneration preview. selection='new' updates the ad with the new image; selection='original' keeps original. """ if request.selection not in ["new", "original"]: raise HTTPException(status_code=400, detail="Selection must be 'new' or 'original'") ad = await db_service.get_ad_creative(request.image_id, username=username) if not ad: raise HTTPException(status_code=404, detail="Ad creative not found or access denied") if request.selection == "original": return {"status": "success", "message": "Original image kept", "selection": "original"} if not request.new_image_url: raise HTTPException(status_code=400, detail="new_image_url is required when selection='new'") regeneration_metadata = { "is_regenerated": True, "regeneration_date": datetime.utcnow().isoformat() + "Z", "regeneration_seed": request.new_seed, } for k, v in [ ("original_image_url", ad.get("r2_url") or ad.get("image_url")), ("original_r2_url", ad.get("r2_url")), ("original_image_filename", ad.get("image_filename")), ("original_image_model", ad.get("image_model")), ("original_seed", ad.get("image_seed")), ]: if v is not None: regeneration_metadata[k] = v update_kwargs = {} if request.new_filename: update_kwargs["image_filename"] = request.new_filename if request.new_model: update_kwargs["image_model"] = request.new_model if request.new_seed is not None: update_kwargs["image_seed"] = request.new_seed if request.new_r2_url: update_kwargs["image_url"] = update_kwargs["r2_url"] = request.new_r2_url elif request.new_image_url: update_kwargs["image_url"] = request.new_image_url update_success = await db_service.update_ad_creative( ad_id=request.image_id, username=username, metadata=regeneration_metadata, **update_kwargs, ) if not update_success: raise HTTPException(status_code=500, detail="Failed to update ad with new image") return {"status": "success", "message": "New image saved", "selection": "new", "new_image_url": request.new_image_url}