"""Image-generation, virtual try-on, and prompt-classification API routes."""

import json
import logging
from io import BytesIO
from typing import Optional

from fastapi import APIRouter, File, Form, HTTPException, Response, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, StreamingResponse
from groq import Groq

from app.core.config import settings
from app.schemas import image as image_schemas
from app.schemas.image import (
    DressCheckResponse,
    ShoeCheckResponse,
    ShoeGenerateRequest,
    UserInput,
)
from app.services import image_service
from app.services.image_service import generate_shoe_images

logger = logging.getLogger(__name__)

router = APIRouter(
    prefix="/api/v1",
    tags=["Image Generation"],
)

# Groq client used by the /check-* text classifiers.
# NOTE(review): a Groq client is built from settings.LANGCHAIN_API_KEY — confirm
# this is the intended credential and not a misnamed/dedicated Groq key.
client = Groq(api_key=settings.LANGCHAIN_API_KEY)

# Model used for the yes/no structured-output classifiers.
_CLASSIFIER_MODEL = "moonshotai/kimi-k2-instruct-0905"


# --- Helpers ---

def create_image_streaming_response(image_data: BytesIO, media_type: str = "image/png"):
    """Wrap in-memory image bytes in a StreamingResponse.

    Args:
        image_data: Buffer holding the encoded image.
        media_type: MIME type sent to the client (defaults to PNG).

    Raises:
        HTTPException: 404 when no buffer was supplied.
    """
    if not image_data:
        raise HTTPException(status_code=404, detail="Image not found or could not be generated.")
    # The producer may have left the cursor at the end after writing.
    image_data.seek(0)
    return StreamingResponse(image_data, media_type=media_type)


def _is_image_upload(upload: UploadFile) -> bool:
    """Return True when the upload declares an ``image/*`` content type."""
    # content_type is None when the client sent no Content-Type for the part;
    # treat that as "not an image" (400) instead of crashing with AttributeError.
    return (upload.content_type or "").startswith("image/")


def _classify_text(text: str, schema_name: str, response_model, system_prompt: str):
    """Run a blocking Groq structured-output classification of ``text``.

    Args:
        text: User text to classify.
        schema_name: Name given to the JSON schema in the request.
        response_model: Pydantic model whose JSON schema constrains the output
            and which is used to validate the reply.
        system_prompt: Classifier instructions sent as the system message.

    Returns:
        An instance of ``response_model`` parsed from the model's JSON reply.
    """
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": schema_name,
            "schema": response_model.model_json_schema(),
        },
    }
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    response = client.chat.completions.create(
        model=_CLASSIFIER_MODEL,
        messages=messages,
        response_format=response_format,
    )
    return response_model.model_validate(json.loads(response.choices[0].message.content))


# --- Endpoints ---
# All image_service / Groq calls are synchronous, so they are dispatched with
# run_in_threadpool to avoid blocking the event loop inside these async routes.

@router.post("/enhance-prompt", response_model=image_schemas.EnhancePromptResponse)
async def enhance_prompt(request: image_schemas.EnhancePromptRequest):
    """Return the raw prompt together with an enhanced version from the image service."""
    logger.info("Received request to /enhance-prompt for: %s...", request.raw_prompt[:50])
    try:
        enhanced_prompt = await run_in_threadpool(
            image_service.enhance_user_prompt, request.raw_prompt
        )
        logger.info("Successfully enhanced prompt.")
        return image_schemas.EnhancePromptResponse(
            raw_prompt=request.raw_prompt,
            enhanced_prompt=enhanced_prompt,
        )
    except Exception as e:
        logger.error("Error in /enhance-prompt: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/generate-image", response_class=StreamingResponse)
async def generate_image(request: image_schemas.GenerateImageRequest):
    """
    Generates an image from either a raw or an enhanced prompt.
    Returns the image directly.

    The enhanced prompt takes precedence when both are supplied.

    Raises:
        HTTPException: 400 if neither prompt is given; 500 if generation fails.
    """
    logger.info("Received request to /generate-image")
    image_prompt = request.enhanced_prompt or request.raw_prompt
    if not image_prompt:
        logger.warning("Bad request to /generate-image: No prompt provided.")
        raise HTTPException(status_code=400, detail="Either raw_prompt or enhanced_prompt must be provided.")

    try:
        # Service returns (model text, image buffer or falsy on failure).
        generated_text, image_bytes_io = await run_in_threadpool(
            image_service.generate_image_from_text, image_prompt
        )
        if not image_bytes_io:
            logger.error("Image generation failed or returned no image.")
            # Surface any text the model produced to help diagnose the failure.
            if generated_text:
                raise HTTPException(status_code=500, detail=f"Image generation failed. Model response: {generated_text}")
            raise HTTPException(status_code=500, detail="Image generation failed: No image data received.")

        logger.info("Successfully generated image. Streaming response.")
        return create_image_streaming_response(image_bytes_io)
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in /generate-image: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/update-image", response_class=StreamingResponse)
async def update_image(
    image: UploadFile = File(..., description="The image to update (PNG, JPG)"),
    text_instruction: str = Form(..., description="The text instruction for what to change."),
):
    """
    Updates an existing image using a text instruction.
    Returns the updated image directly.

    Raises:
        HTTPException: 400 if the upload is not an image; 500 if the update fails.
    """
    logger.info("Received request to /update-image")
    if not _is_image_upload(image):
        raise HTTPException(status_code=400, detail="Image must be an image type (e.g., image/png, image/jpeg).")

    try:
        image_bytes = await image.read()
        updated_text, updated_image_bytes_io = await run_in_threadpool(
            image_service.update_image_with_text,
            text_instruction=text_instruction,
            image_bytes=image_bytes,
        )
        if not updated_image_bytes_io:
            logger.error("Image update failed or returned no image.")
            if updated_text:
                raise HTTPException(status_code=500, detail=f"Image update failed. Model response: {updated_text}")
            raise HTTPException(status_code=500, detail="Image update failed: No image data received.")

        logger.info("Successfully updated image. Streaming response.")
        return create_image_streaming_response(updated_image_bytes_io)
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in /update-image: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/virtual-try-on")
async def virtual_try_on(
    dress_image: UploadFile = File(..., description="The dress image for try-on (PNG, JPG)"),
    person_image: UploadFile = File(..., description="The person image for try-on (PNG, JPG)"),
    shoes_image: Optional[UploadFile] = File(None, description="The shoes image for try-on (PNG, JPG)"),
) -> Response:
    """
    Performs a virtual try-on using dress, person, and optional shoes.

    Returns the try-on image directly, or a JSON response containing the
    service's summary (status 200) when no image was produced.

    Raises:
        HTTPException: 400 if any supplied upload is not an image; 500 on failure.
    """
    logger.info("Received request to /virtual-try-on with image uploads.")

    if not _is_image_upload(dress_image):
        raise HTTPException(status_code=400, detail="Dress image must be an image type.")
    if not _is_image_upload(person_image):
        raise HTTPException(status_code=400, detail="Person image must be an image type.")
    # Shoes are optional; only validate when one was uploaded.
    if shoes_image and not _is_image_upload(shoes_image):
        raise HTTPException(status_code=400, detail="Shoes image must be an image type.")

    try:
        dress_image_bytes = await dress_image.read()
        person_image_bytes = await person_image.read()
        shoes_image_bytes = await shoes_image.read() if shoes_image else None

        summary, try_on_image_bytes_io = await run_in_threadpool(
            image_service.virtual_try_on,
            dress_image_bytes=dress_image_bytes,
            person_image_bytes=person_image_bytes,
            shoes_image_bytes=shoes_image_bytes,
        )

        if try_on_image_bytes_io:
            logger.info("Virtual try-on successful. Streaming image response.")
            return create_image_streaming_response(try_on_image_bytes_io)

        logger.warning("Virtual try-on returned no image. Sending JSON summary.")
        return JSONResponse(content=summary)
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in /virtual-try-on: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/generate-shoe", response_class=StreamingResponse)
async def generate_shoe(request: ShoeGenerateRequest):
    """
    Generate shoe product images from a raw prompt.
    Streams the first generated image back.

    When no image is produced, responds with status 500 and a JSON body
    ``{"success": False, "notes": ..., "model_text"?: ...}``.

    Raises:
        HTTPException: 400 if ``prompt`` is empty; 500 on unexpected errors.
    """
    # Validate before logging: slicing a missing prompt would otherwise crash.
    if not request.prompt:
        raise HTTPException(status_code=400, detail="`prompt` field is required.")
    logger.info("Received request to /generate-shoe: %s", request.prompt[:80])

    try:
        gen_text, images = await run_in_threadpool(generate_shoe_images, prompt=request.prompt)
        if images:
            return create_image_streaming_response(images[0])

        # No image: report failure plus any text the model returned.
        detail = {"success": False, "notes": "No image generated."}
        if gen_text:
            detail["model_text"] = gen_text
        return JSONResponse(status_code=500, content=detail)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in /generate-shoe: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/check-shoe", response_model=ShoeCheckResponse)
async def check_shoe(user_input: UserInput):
    """
    Check if the user input text is related to shoe images using Groq model.
    Returns {"answer": "Yes"} or {"answer": "No"}.
    """
    try:
        return await run_in_threadpool(
            _classify_text,
            user_input.text,
            "shoe_check",
            ShoeCheckResponse,
            "You are a strict classifier that determines if text is about shoe images.",
        )
    except Exception as e:
        logger.error("Error in /check-shoe: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/check-dress", response_model=DressCheckResponse)
async def check_dress(user_input: UserInput):
    """
    Check if the user input text is related to dress images using Groq model.
    Returns {"answer": "Yes"} or {"answer": "No"}.
    """
    try:
        return await run_in_threadpool(
            _classify_text,
            user_input.text,
            "dress_check",
            DressCheckResponse,
            "You are a strict classifier that determines if text is about dress images.",
        )
    except Exception as e:
        logger.error("Error in /check-dress: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e