File size: 11,518 Bytes
f2638df
 
 
 
d47078a
f2638df
c3d84f4
d47078a
 
1f370be
f2638df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe3b199
f2638df
 
fe3b199
 
 
 
f2638df
fe3b199
 
f2638df
 
 
fe3b199
f2638df
bbabac5
f2638df
bbabac5
 
fe3b199
 
bbabac5
 
f2638df
 
fe3b199
f2638df
 
bbabac5
 
 
fe3b199
f2638df
fe3b199
f2638df
 
bbabac5
 
f2638df
 
fe3b199
f2638df
 
 
 
 
 
 
 
 
 
 
d47078a
 
fe3b199
d47078a
 
 
7bb1cd2
 
d47078a
7bb1cd2
d47078a
7bb1cd2
 
d47078a
 
7bb1cd2
 
d47078a
 
7bb1cd2
d47078a
 
 
 
7bb1cd2
d47078a
 
 
 
 
 
 
 
 
 
1f370be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import logging
from io import BytesIO
from typing import Optional

from fastapi import APIRouter, HTTPException, Response, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, JSONResponse

from app.schemas import image as image_schemas
from app.schemas.image import ShoeGenerateRequest
from app.schemas.image import ShoeCheckResponse, UserInput, DressCheckResponse
from app.services import image_service
from app.services.image_service import generate_shoe_images

# Logger named after this module.
logger = logging.getLogger(__name__)

# Every endpoint in this module is mounted under /api/v1 and shares one OpenAPI tag.
router = APIRouter(prefix="/api/v1", tags=["Image Generation"])

# --- Helper function for streaming image responses ---
def create_image_streaming_response(image_data: Optional[BytesIO], media_type: str = "image/png"):
    """
    Wrap an in-memory image buffer in a StreamingResponse.

    Args:
        image_data: Buffer holding the encoded image. ``None`` or an empty
            buffer is treated as "no image".
        media_type: Content-Type sent with the response (PNG by default).

    Returns:
        A StreamingResponse that streams the buffer from its first byte.

    Raises:
        HTTPException: 404 when there is no image data to send.
    """
    from io import SEEK_END

    # A BytesIO object is always truthy, so ``not image_data`` would only ever
    # catch None; an empty buffer must be detected by its size instead.
    if image_data is None:
        raise HTTPException(status_code=404, detail="Image not found or could not be generated.")

    image_data.seek(0, SEEK_END)
    if image_data.tell() == 0:
        raise HTTPException(status_code=404, detail="Image not found or could not be generated.")

    image_data.seek(0)  # Rewind so the response streams from the first byte.
    return StreamingResponse(image_data, media_type=media_type)

# --- Endpoints ---

@router.post("/enhance-prompt", response_model=image_schemas.EnhancePromptResponse)
async def enhance_prompt(request: image_schemas.EnhancePromptRequest):
    """
    Rewrite the user's raw prompt into an enhanced prompt.

    Returns both the original and the enhanced prompt.

    Raises:
        HTTPException: 500 if the enhancement service fails. HTTPExceptions
            raised by the service are passed through unchanged.
    """
    logger.info("Received request to /enhance-prompt for: %s...", request.raw_prompt[:50])
    try:
        # The service call is synchronous; run it in a worker thread so the
        # event loop is not blocked while the prompt is being enhanced.
        enhanced_prompt = await run_in_threadpool(
            image_service.enhance_user_prompt, request.raw_prompt
        )
        logger.info("Successfully enhanced prompt.")
        return image_schemas.EnhancePromptResponse(
            raw_prompt=request.raw_prompt,
            enhanced_prompt=enhanced_prompt
        )
    except HTTPException:  # Keep deliberate HTTP errors intact, like the other endpoints.
        raise
    except Exception as e:
        logger.error("Error in /enhance-prompt: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/generate-image", response_class=StreamingResponse)
async def generate_image(request: image_schemas.GenerateImageRequest):
    """
    Generate an image from the enhanced prompt, falling back to the raw prompt.

    The enhanced prompt wins when both are non-empty. The generated image is
    returned directly as a streamed response.

    Raises:
        HTTPException: 400 if neither prompt is provided; 500 if generation
            fails or yields no image (the model's text reply, if any, is
            included in the detail).
    """
    logger.info("Received request to /generate-image")
    try:
        # ``or`` picks the first non-empty prompt, preserving the original
        # enhanced-over-raw precedence.
        image_prompt = request.enhanced_prompt or request.raw_prompt
        if not image_prompt:
            logger.warning("Bad request to /generate-image: No prompt provided.")
            raise HTTPException(status_code=400, detail="Either raw_prompt or enhanced_prompt must be provided.")

        # The service is synchronous (returns text and a BytesIO); run it in a
        # worker thread so image generation does not stall the event loop.
        generated_text, image_bytes_io = await run_in_threadpool(
            image_service.generate_image_from_text, image_prompt
        )

        if not image_bytes_io:
            logger.error("Image generation failed or returned no image.")
            # Surface the model's text reply when there is one; it usually explains the failure.
            if generated_text:
                raise HTTPException(status_code=500, detail=f"Image generation failed. Model response: {generated_text}")
            raise HTTPException(status_code=500, detail="Image generation failed: No image data received.")

        logger.info("Successfully generated image. Streaming response.")
        return create_image_streaming_response(image_bytes_io)

    except HTTPException:  # Re-raise HTTPExceptions directly
        raise
    except Exception as e:
        logger.error("Error in /generate-image: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

# --- Image update endpoint ---
@router.post("/update-image", response_class=StreamingResponse)
async def update_image(
    image: UploadFile = File(..., description="The image to update (PNG, JPG)"),
    text_instruction: str = Form(..., description="The text instruction for what to change.")
):
    """
    Update an uploaded image according to a text instruction.

    Returns the updated image directly as a streamed response.

    Raises:
        HTTPException: 400 if the upload is not declared as an image type;
            500 if the update fails or yields no image (the model's text
            reply, if any, is included in the detail).
    """
    logger.info("Received request to /update-image")

    # ``content_type`` is None when the multipart part has no Content-Type
    # header. Treat that as "not an image" instead of letting ``None.startswith``
    # raise AttributeError, which would escape as an unhandled 500.
    if not (image.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Image must be an image type (e.g., image/png, image/jpeg).")

    try:
        image_bytes = await image.read()

        # The service is synchronous (returns text and a BytesIO); run it in a
        # worker thread so the event loop keeps serving other requests.
        updated_text, updated_image_bytes_io = await run_in_threadpool(
            image_service.update_image_with_text,
            text_instruction=text_instruction,
            image_bytes=image_bytes,
        )

        if not updated_image_bytes_io:
            logger.error("Image update failed or returned no image.")
            if updated_text:
                raise HTTPException(status_code=500, detail=f"Image update failed. Model response: {updated_text}")
            raise HTTPException(status_code=500, detail="Image update failed: No image data received.")

        logger.info("Successfully updated image. Streaming response.")
        return create_image_streaming_response(updated_image_bytes_io)

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in /update-image: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    
    
# --- Virtual try-on endpoint (dress, person, optional shoes) ---
@router.post("/virtual-try-on")
async def virtual_try_on(
    dress_image: UploadFile = File(..., description="The dress image for try-on (PNG, JPG)"),
    person_image: UploadFile = File(..., description="The person image for try-on (PNG, JPG)"),
    shoes_image: Optional[UploadFile] = File(None, description="The shoes image for try-on (PNG, JPG)")
) -> Response:
    """
    Perform a virtual try-on from a dress image, a person image and optional shoes.

    Returns the try-on image as a streamed response. If the service produces
    no image, its summary is returned as JSON instead.

    Raises:
        HTTPException: 400 if any supplied upload is not declared as an image
            type; 500 if the try-on service fails.
    """
    logger.info("Received request to /virtual-try-on with image uploads.")

    # 1. Validate mandatory images. ``content_type`` is None when a multipart
    #    part has no Content-Type header, so coerce it to "" before testing;
    #    otherwise ``None.startswith`` escapes as an unhandled 500.
    if not (dress_image.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Dress image must be an image type.")
    if not (person_image.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Person image must be an image type.")

    # 2. Validate the optional shoes image only when one was uploaded.
    if shoes_image and not (shoes_image.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Shoes image must be an image type.")

    try:
        # 3. Read the uploaded bytes.
        dress_image_bytes = await dress_image.read()
        person_image_bytes = await person_image.read()
        shoes_image_bytes = await shoes_image.read() if shoes_image else None

        # 4. The service is synchronous; run it in a worker thread so the
        #    event loop is not blocked for the duration of the try-on.
        summary, try_on_image_bytes_io = await run_in_threadpool(
            image_service.virtual_try_on,
            dress_image_bytes=dress_image_bytes,
            person_image_bytes=person_image_bytes,
            shoes_image_bytes=shoes_image_bytes,
        )

        # 5. Stream the image when there is one; otherwise fall back to the summary.
        if try_on_image_bytes_io:
            logger.info("Virtual try-on successful. Streaming image response.")
            return create_image_streaming_response(try_on_image_bytes_io)

        logger.warning("Virtual try-on returned no image. Sending JSON summary.")
        return JSONResponse(content=summary)

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in /virtual-try-on: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

        
@router.post("/generate-shoe", response_class=StreamingResponse)
async def generate_shoe(request: ShoeGenerateRequest):
    """
    Generate shoe product images from a raw prompt.

    Streams the first generated image back as PNG. If no image is produced,
    responds with a 500 JSON body ``{"success": False, "notes": ...}``. That
    body also carries ``"model_text"`` when the model returned text.

    Raises:
        HTTPException: 400 if the prompt is missing or empty; 500 if the
            generator raises.
    """
    # Validate before logging: slicing a missing prompt (if the schema allows
    # None) would raise TypeError and surface as a 500 instead of this 400.
    if not request.prompt:
        raise HTTPException(status_code=400, detail="`prompt` field is required.")

    logger.info("Received request to /generate-shoe: %s", request.prompt[:80])

    try:
        # The generator is synchronous; run it in a worker thread so the
        # event loop is not blocked while images are produced.
        gen_text, images = await run_in_threadpool(generate_shoe_images, prompt=request.prompt)

        if not images:
            # Handle failure / text-only response.
            detail = {"success": False, "notes": "No image generated."}
            if gen_text:
                detail["model_text"] = gen_text
            return JSONResponse(status_code=500, content=detail)

        # Stream the first image from its beginning.
        img_io: BytesIO = images[0]
        img_io.seek(0)
        return StreamingResponse(img_io, media_type="image/png")

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in /generate-shoe: %s", e)
        raise HTTPException(status_code=500, detail=str(e))

from app.core.config import settings

from groq import Groq
import json

# Shared Groq client used by the text-classification endpoints in this module.
# NOTE(review): the key is read from ``settings.LANGCHAIN_API_KEY``; confirm that
# this setting really holds a Groq API key rather than a LangChain one.
client = Groq(
    api_key=settings.LANGCHAIN_API_KEY,
)

# --- Text classification endpoints ---

# Groq model used for the yes/no classification calls below.
_CLASSIFIER_MODEL = "moonshotai/kimi-k2-instruct-0905"


async def _classify_text(text, response_model, schema_name, system_prompt):
    """
    Ask the Groq model to classify *text* and parse the reply into *response_model*.

    Args:
        text: The user-supplied text to classify.
        response_model: Pydantic model whose JSON schema constrains the reply
            and which validates the parsed result.
        schema_name: Name given to the JSON schema in the request.
        system_prompt: System message describing the classification task.

    Returns:
        An instance of *response_model* built from the model's JSON reply.
    """
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": schema_name,
            "schema": response_model.model_json_schema(),
        },
    }
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    # ``client.chat.completions.create`` is a blocking call; run it in a
    # worker thread so this coroutine does not stall the event loop.
    response = await run_in_threadpool(
        client.chat.completions.create,
        model=_CLASSIFIER_MODEL,
        messages=messages,
        response_format=response_format,
    )
    return response_model.model_validate(json.loads(response.choices[0].message.content))


@router.post("/check-shoe", response_model=ShoeCheckResponse)
async def check_shoe(user_input: UserInput):
    """
    Check whether the user's text is about shoe images, using the Groq model.

    Returns a ShoeCheckResponse parsed from the model's JSON reply (documented
    as {"answer": "Yes"} or {"answer": "No"}).

    Raises:
        HTTPException: 500 if the model call or response parsing fails.
    """
    try:
        return await _classify_text(
            user_input.text,
            ShoeCheckResponse,
            "shoe_check",
            "You are a strict classifier that determines if text is about shoe images.",
        )
    except Exception as e:
        logger.error("Error in /check-shoe: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


# --- Dress Check Endpoint ---
@router.post("/check-dress", response_model=DressCheckResponse)
async def check_dress(user_input: UserInput):
    """
    Check whether the user's text is about dress images, using the Groq model.

    Returns a DressCheckResponse parsed from the model's JSON reply (documented
    as {"answer": "Yes"} or {"answer": "No"}).

    Raises:
        HTTPException: 500 if the model call or response parsing fails.
    """
    try:
        return await _classify_text(
            user_input.text,
            DressCheckResponse,
            "dress_check",
            "You are a strict classifier that determines if text is about dress images.",
        )
    except Exception as e:
        logger.error("Error in /check-dress: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e