File size: 12,830 Bytes
d4a4da7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c9d10c
 
 
d4a4da7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
"""Image correction and regeneration endpoints."""

import os
import time
import uuid
import random
import logging
from datetime import datetime
from fastapi import APIRouter, HTTPException, Depends

from api.schemas import (
    ImageCorrectRequest,
    ImageCorrectResponse,
    ImageRegenerateRequest,
    ImageRegenerateResponse,
    ImageSelectionRequest,
)
from services.correction import correction_service
from services.database import db_service
from services.image import image_service
from services.auth_dependency import get_current_user
from config import settings

# Router that the image correction / regeneration endpoints in this module
# register on; every route attached to it carries the "correction" tag.
router = APIRouter(tags=["correction"])
# Logger named "api", used by the handlers below for request and failure logging.
api_logger = logging.getLogger("api")


@router.post("/api/correct", response_model=ImageCorrectResponse)
async def correct_image(
    request: ImageCorrectRequest,
    username: str = Depends(get_current_user),
):
    """
    Analyze an image for spelling and visual problems and regenerate a fixed version.

    Requires authentication. When ``request.image_id`` is anything other than the
    placeholder ``"temp-id"``, the ad creative is looked up for the current user
    and, on a successful correction that carries ``_db_metadata``, the record is
    updated with the corrected image while the original image details are kept
    in the metadata passed to the update.

    Raises:
        HTTPException: 404 when the ad creative or the image cannot be loaded,
            400 when no image URL is available, 500 on any unexpected failure.
    """
    started_at = time.time()
    api_logger.info("API: Correction request received | User: %s | Image ID: %s", username, request.image_id)

    try:
        # "temp-id" marks an image that has no database record behind it.
        is_stored = request.image_id != "temp-id"
        source_url = request.image_url
        ad = None
        if is_stored:
            ad = await db_service.get_ad_creative(request.image_id, username=username)
            if not ad:
                raise HTTPException(status_code=404, detail=f"Ad creative with ID {request.image_id} not found or access denied")
            # The URL sent by the caller wins; the stored URLs are only a fallback.
            source_url = source_url or ad.get("r2_url") or ad.get("image_url")

        if not source_url:
            raise HTTPException(status_code=400, detail="Image URL must be provided for images not in database, or found in database for provided ID")

        raw_image = await image_service.load_image(
            image_id=request.image_id if is_stored else None,
            image_url=source_url,
            image_bytes=None,
            filepath=None,
        )
        if not raw_image:
            raise HTTPException(status_code=404, detail="Image not found for analysis. Please ensure the URL is accessible.")

        # An empty mapping stands in for the record when there is none, so the
        # lookups below yield the same defaults (None prompt, "others" niche).
        record = ad if ad else {}
        result = await correction_service.correct_image(
            image_bytes=raw_image,
            image_url=source_url,
            original_prompt=record.get("image_prompt"),
            width=1024,
            height=1024,
            niche=record.get("niche", "others"),
            user_instructions=request.user_instructions,
            auto_analyze=request.auto_analyze,
        )

        status = result["status"]
        corrections = result.get("corrections")
        corrected = result.get("corrected_image")

        response_data = {
            "status": status,
            "analysis": result.get("analysis"),
            "corrections": None,
            "corrected_image": None,
            "error": result.get("error"),
        }
        if corrections:
            response_data["corrections"] = {
                "spelling_corrections": corrections.get("spelling_corrections", []),
                "visual_corrections": corrections.get("visual_corrections", []),
                "corrected_prompt": corrections.get("corrected_prompt", ""),
            }
        if corrected:
            response_data["corrected_image"] = {
                field: corrected.get(field)
                for field in (
                    "filename",
                    "filepath",
                    "image_url",
                    "r2_url",
                    "model_used",
                    "corrected_prompt",
                )
            }

        db_metadata = result.get("_db_metadata")
        if status == "success" and db_metadata and ad:
            # Remember where the image came from; empty values are left out.
            originals = {
                "original_image_url": ad.get("r2_url") or ad.get("image_url"),
                "original_r2_url": ad.get("r2_url"),
                "original_image_filename": ad.get("image_filename"),
                "original_image_model": ad.get("image_model"),
                "original_image_prompt": ad.get("image_prompt"),
            }
            correction_metadata = {
                "is_corrected": True,
                "correction_date": datetime.utcnow().isoformat() + "Z",
            }
            correction_metadata.update(
                {key: value for key, value in originals.items() if value}
            )
            if corrections:
                correction_metadata["corrections"] = corrections

            update_kwargs = {
                "image_url": db_metadata.get("image_url"),
                "image_filename": db_metadata.get("filename"),
                "image_model": db_metadata.get("model_used"),
                "image_prompt": db_metadata.get("corrected_prompt"),
            }
            new_r2_url = db_metadata.get("r2_url")
            if new_r2_url:
                update_kwargs["r2_url"] = new_r2_url

            saved = await db_service.update_ad_creative(
                ad_id=request.image_id,
                username=username,
                metadata=correction_metadata,
                **update_kwargs,
            )
            # Only advertise the ad ID once the record really points at the new image.
            if saved and response_data.get("corrected_image") is not None:
                response_data["corrected_image"]["ad_id"] = request.image_id

        api_logger.info("Correction request completed in %.2fs", time.time() - started_at)
        return response_data
    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        api_logger.exception("Correction failed")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/api/regenerate", response_model=ImageRegenerateResponse)
async def regenerate_image(
    request: ImageRegenerateRequest,
    username: str = Depends(get_current_user),
):
    """
    Regenerate an image for an existing ad creative with an optional new model.

    The stored image prompt is rendered again with a fresh random seed. The
    result is uploaded to R2 on a best-effort basis and, whenever image bytes
    are available, also written to ``settings.output_dir`` so that
    ``/images/{filename}`` can serve as a fallback URL.

    If ``request.preview_only`` is true the ad creative is left untouched and
    only a preview is returned; use the confirm endpoint to save it. Otherwise
    the ad creative is updated with the new image, and the previous image
    details are recorded in the metadata passed to the update.

    Args:
        request: Supplies ``image_id``, an optional ``image_model`` override
            and the ``preview_only`` flag.
        username: Authenticated user, resolved by ``get_current_user``.

    Returns:
        A dict with ``status``, ``regenerated_image``, ``original_image_url``,
        ``original_preserved`` and ``is_preview``.

    Raises:
        HTTPException: 404 when the ad creative is missing or not accessible,
            400 when it has no image prompt, 500 when generation yields no
            image, when the database update fails, or on any unexpected error.
    """
    api_start_time = time.time()
    try:
        ad = await db_service.get_ad_creative(request.image_id, username=username)
        if not ad:
            raise HTTPException(status_code=404, detail="Ad creative not found or access denied")
        image_prompt = ad.get("image_prompt")
        if not image_prompt:
            raise HTTPException(status_code=400, detail="No image prompt found for this ad creative.")

        # Precedence: explicit request override, then the ad's model, then the default.
        model_to_use = request.image_model or ad.get("image_model") or settings.image_model
        seed = random.randint(1, 2147483647)
        image_bytes, model_used, generated_url = await image_service.generate(
            prompt=image_prompt,
            width=1024,
            height=1024,
            seed=seed,
            model_key=model_to_use,
        )
        # Without bytes or a provider URL there is nothing to show or store;
        # fail loudly instead of returning a URL to a file that was never written.
        if not image_bytes and not generated_url:
            raise HTTPException(status_code=500, detail="Image generation returned no image data")

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        unique_id = uuid.uuid4().hex[:8]
        # ``or`` (rather than a .get default) also covers a niche stored as None.
        niche = str(ad.get("niche") or "unknown").replace(" ", "_")
        # Keep the filename a single safe path component: anything other than
        # letters, digits, "-" and "_" (notably path separators) becomes "_".
        filename_niche = "".join(
            ch if ch.isalnum() or ch in "-_" else "_" for ch in niche
        )
        filename = f"regen_{filename_niche}_{timestamp}_{unique_id}.png"

        # NOTE(review): the upload and the file write below are synchronous calls
        # inside an async handler - confirm that blocking here is acceptable.
        r2_url = None
        try:
            from services.r2_storage import get_r2_storage
            r2_storage = get_r2_storage()
            if r2_storage and image_bytes:
                r2_url = r2_storage.upload_image(image_bytes=image_bytes, filename=filename, niche=niche)
        except Exception as e:
            # R2 is best-effort; the local copy / generated URL act as fallbacks.
            api_logger.warning("R2 upload failed: %s", e)

        local_path = None
        # Always save locally when we have bytes so /images/{filename} works as fallback
        # (avoids ERR_HTTP2_SERVER_REFUSED_STREAM when R2 presigned URL fails in browser)
        if image_bytes:
            local_path = os.path.join(settings.output_dir, filename)
            output_dir = os.path.dirname(local_path)
            if output_dir:  # os.makedirs("") raises, so skip for a bare filename
                os.makedirs(output_dir, exist_ok=True)
            with open(local_path, "wb") as f:
                f.write(image_bytes)

        original_image_url = ad.get("r2_url") or ad.get("image_url")
        # Thanks to the guard above, the local fallback is only reached when
        # the file was actually written.
        new_image_url = r2_url or generated_url or f"/images/{filename}"

        response = {
            "status": "success",
            "regenerated_image": {
                "filename": filename,
                "filepath": local_path,
                "image_url": new_image_url,
                "r2_url": r2_url,
                "model_used": model_used,
                "prompt_used": image_prompt,
                "seed_used": seed,
            },
            "original_image_url": original_image_url,
            "original_preserved": True,
            "is_preview": bool(request.preview_only),
        }

        if not request.preview_only:
            regeneration_metadata = {
                "is_regenerated": True,
                "regeneration_date": datetime.utcnow().isoformat() + "Z",
                "regeneration_seed": seed,
            }
            if original_image_url:
                regeneration_metadata["original_image_url"] = original_image_url
            previous = {
                "original_r2_url": ad.get("r2_url"),
                "original_image_filename": ad.get("image_filename"),
                "original_image_model": ad.get("image_model"),
                "original_seed": ad.get("image_seed"),
            }
            regeneration_metadata.update(
                {key: value for key, value in previous.items() if value is not None}
            )

            update_kwargs = {
                "image_filename": filename,
                "image_model": model_used,
                "image_seed": seed,
                "image_url": new_image_url,
            }
            if r2_url:
                update_kwargs["r2_url"] = r2_url

            update_success = await db_service.update_ad_creative(
                ad_id=request.image_id,
                username=username,
                metadata=regeneration_metadata,
                **update_kwargs,
            )
            # Same contract as the confirm endpoint: a failed save is an error,
            # not a silent "success".
            if not update_success:
                raise HTTPException(status_code=500, detail="Failed to update ad with new image")

        api_logger.info("Regeneration request completed in %.2fs", time.time() - api_start_time)
        return response
    except HTTPException:
        raise
    except Exception as e:
        api_logger.exception("Regeneration failed")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/api/regenerate/confirm")
async def confirm_image_selection(
    request: ImageSelectionRequest,
    username: str = Depends(get_current_user),
):
    """
    Apply the user's choice after a regeneration preview.

    ``selection='original'`` leaves the ad creative untouched, while
    ``selection='new'`` updates it with the previewed image and records the
    previous image details in the metadata passed to the update.

    Raises:
        HTTPException: 400 for an unknown selection or a missing
            ``new_image_url``, 404 when the ad creative is not found or not
            accessible, 500 when the database update fails.
    """
    choice = request.selection
    if choice not in ("new", "original"):
        raise HTTPException(status_code=400, detail="Selection must be 'new' or 'original'")

    ad = await db_service.get_ad_creative(request.image_id, username=username)
    if not ad:
        raise HTTPException(status_code=404, detail="Ad creative not found or access denied")

    if choice == "original":
        return {"status": "success", "message": "Original image kept", "selection": "original"}

    if not request.new_image_url:
        raise HTTPException(status_code=400, detail="new_image_url is required when selection='new'")

    # Snapshot of the image being replaced; fields that are None are left out.
    previous = {
        "original_image_url": ad.get("r2_url") or ad.get("image_url"),
        "original_r2_url": ad.get("r2_url"),
        "original_image_filename": ad.get("image_filename"),
        "original_image_model": ad.get("image_model"),
        "original_seed": ad.get("image_seed"),
    }
    regeneration_metadata = {
        "is_regenerated": True,
        "regeneration_date": datetime.utcnow().isoformat() + "Z",
        "regeneration_seed": request.new_seed,
    }
    regeneration_metadata.update(
        {key: value for key, value in previous.items() if value is not None}
    )

    # Filename and model are only sent when non-empty; the seed when it is not None.
    update_kwargs = {
        column: value
        for column, value, wanted in (
            ("image_filename", request.new_filename, bool(request.new_filename)),
            ("image_model", request.new_model, bool(request.new_model)),
            ("image_seed", request.new_seed, request.new_seed is not None),
        )
        if wanted
    }
    if request.new_r2_url:
        # The R2 URL, when present, doubles as the primary image URL.
        update_kwargs["image_url"] = request.new_r2_url
        update_kwargs["r2_url"] = request.new_r2_url
    else:
        # new_image_url is known to be non-empty thanks to the guard above.
        update_kwargs["image_url"] = request.new_image_url

    saved = await db_service.update_ad_creative(
        ad_id=request.image_id,
        username=username,
        metadata=regeneration_metadata,
        **update_kwargs,
    )
    if not saved:
        raise HTTPException(status_code=500, detail="Failed to update ad with new image")

    return {"status": "success", "message": "New image saved", "selection": "new", "new_image_url": request.new_image_url}