Spaces:
Runtime error
Runtime error
| import logging | |
| from fastapi import APIRouter, HTTPException, Response, UploadFile, File, Form | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from app.schemas import image as image_schemas | |
| from app.schemas.image import ShoeGenerateRequest | |
| from io import BytesIO | |
| from typing import Optional | |
| from app.services import image_service | |
| from app.services.image_service import generate_shoe_images | |
| from app.schemas.image import ShoeCheckResponse, UserInput, DressCheckResponse | |
# Per-module logger so records from this router can be filtered by module name.
logger = logging.getLogger(__name__)

# Routes registered on this router are served under "/api/v1" and grouped under
# the "Image Generation" tag in the generated OpenAPI docs.
router = APIRouter(tags=["Image Generation"], prefix="/api/v1")
# --- Helper function for streaming image responses ---
def create_image_streaming_response(image_data: BytesIO, media_type: str = "image/png"):
    """
    Wrap an in-memory image buffer in a StreamingResponse.

    Args:
        image_data: Buffer holding the encoded image. A falsy value (e.g.
            ``None``) is treated as "no image".
        media_type: Content type sent to the client; PNG by default.

    Returns:
        A ``StreamingResponse`` that streams ``image_data`` from its beginning.

    Raises:
        HTTPException: 404 when ``image_data`` is falsy.
    """
    if image_data:
        # Rewind first: the producer may have left the buffer position at the end
        # after writing, which would otherwise stream an empty body.
        image_data.seek(0)
        return StreamingResponse(image_data, media_type=media_type)
    raise HTTPException(status_code=404, detail="Image not found or could not be generated.")
# --- Endpoints ---
# NOTE(review): no `@router...` decorator or `router.add_api_route(...)` for this
# handler is visible here, so it may not be exposed as a route — confirm.
async def enhance_prompt(request: image_schemas.EnhancePromptRequest):
    """
    Rewrite a user's raw prompt into an enhanced prompt via the image service.

    Args:
        request: Body carrying ``raw_prompt``.

    Returns:
        ``EnhancePromptResponse`` holding both the raw and the enhanced prompt.

    Raises:
        HTTPException: re-raised unchanged if the service raises one; any other
            error is logged and reported as a 500 whose detail is the error text.
    """
    logger.info("Received request to /enhance-prompt for: %s...", request.raw_prompt[:50])
    try:
        enhanced_prompt = image_service.enhance_user_prompt(request.raw_prompt)
        logger.info("Successfully enhanced prompt.")
        return image_schemas.EnhancePromptResponse(
            raw_prompt=request.raw_prompt,
            enhanced_prompt=enhanced_prompt,
        )
    except HTTPException:
        # Preserve deliberate HTTP errors (status + detail) instead of turning
        # them into a 500, matching the other endpoints in this module.
        raise
    except Exception as e:
        logger.error("Error in /enhance-prompt: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator for this handler is visible here — confirm it
# is registered on `router` (the original comment refers to a response class).
async def generate_image(request: image_schemas.GenerateImageRequest):
    """
    Generate an image from a prompt and stream it back.

    The enhanced prompt takes priority; the raw prompt is the fallback.

    Raises:
        HTTPException: 400 when neither prompt is provided; 500 when the service
            yields no image (detail includes the model text if any) or fails.
    """
    logger.info("Received request to /generate-image")
    try:
        image_prompt = request.enhanced_prompt or request.raw_prompt
        if not image_prompt:
            logger.warning("Bad request to /generate-image: No prompt provided.")
            raise HTTPException(status_code=400, detail="Either raw_prompt or enhanced_prompt must be provided.")

        # The service hands back (model text, image buffer); either may be empty.
        generated_text, image_bytes_io = image_service.generate_image_from_text(image_prompt)
        if image_bytes_io:
            logger.info("Successfully generated image. Streaming response.")
            return create_image_streaming_response(image_bytes_io)

        logger.error("Image generation failed or returned no image.")
        failure_detail = (
            f"Image generation failed. Model response: {generated_text}"
            if generated_text
            else "Image generation failed: No image data received."
        )
        raise HTTPException(status_code=500, detail=failure_detail)
    except HTTPException:
        # Errors raised deliberately above keep their own status code.
        raise
    except Exception as e:
        logger.error(f"Error in /generate-image: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
# --- update_image ---
# NOTE(review): no route decorator for this handler is visible here — confirm it
# is registered on `router`.
async def update_image(
    image: UploadFile = File(..., description="The image to update (PNG, JPG)"),
    text_instruction: str = Form(..., description="The text instruction for what to change."),
):
    """
    Edit an uploaded image according to a text instruction and stream the result.

    Args:
        image: Uploaded image file; its declared content type must start with "image/".
        text_instruction: Description of the change to apply.

    Returns:
        A streaming response containing the updated image.

    Raises:
        HTTPException: 400 if the upload is not declared as an image (including a
            missing Content-Type); 500 if the service returns no image or fails.
    """
    logger.info("Received request to /update-image")
    # `UploadFile.content_type` is None when the part carries no Content-Type
    # header; `or ""` turns that into the intended 400 instead of an
    # AttributeError escaping as an unhandled 500.
    if not (image.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Image must be an image type (e.g., image/png, image/jpeg).")
    try:
        image_bytes = await image.read()
        # The service returns (model text, image buffer); either may be empty.
        updated_text, updated_image_bytes_io = image_service.update_image_with_text(
            text_instruction=text_instruction,
            image_bytes=image_bytes,
        )
        if not updated_image_bytes_io:
            logger.error("Image update failed or returned no image.")
            if updated_text:
                raise HTTPException(status_code=500, detail=f"Image update failed. Model response: {updated_text}")
            raise HTTPException(status_code=500, detail="Image update failed: No image data received.")
        logger.info("Successfully updated image. Streaming response.")
        return create_image_streaming_response(updated_image_bytes_io)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in /update-image: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# --- Virtual try-on endpoint: dress + person images, optional shoes ---
# NOTE(review): no route decorator for this handler is visible here — confirm it
# is registered on `router`.
async def virtual_try_on(
    dress_image: UploadFile = File(..., description="The dress image for try-on (PNG, JPG)"),
    person_image: UploadFile = File(..., description="The person image for try-on (PNG, JPG)"),
    shoes_image: Optional[UploadFile] = File(None, description="The shoes image for try-on (PNG, JPG)"),
) -> Response:
    """
    Perform a virtual try-on from a dress image, a person image and optional shoes.

    Returns:
        A streaming image response when the service produced an image;
        otherwise a 200 ``JSONResponse`` whose body is the service's summary.

    Raises:
        HTTPException: 400 if any supplied upload is not declared as an image
            (including a missing Content-Type); 500 on any other failure.
    """
    logger.info("Received request to /virtual-try-on with image uploads.")
    # `UploadFile.content_type` is None when the part has no Content-Type header;
    # `or ""` makes that a clean 400 rather than an AttributeError (unhandled 500).
    if not (dress_image.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Dress image must be an image type.")
    if not (person_image.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Person image must be an image type.")
    # Shoes are optional: only validate when an upload was supplied.
    if shoes_image and not (shoes_image.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Shoes image must be an image type.")
    try:
        dress_image_bytes = await dress_image.read()
        person_image_bytes = await person_image.read()
        shoes_image_bytes = await shoes_image.read() if shoes_image else None

        summary, try_on_image_bytes_io = image_service.virtual_try_on(
            dress_image_bytes=dress_image_bytes,
            person_image_bytes=person_image_bytes,
            shoes_image_bytes=shoes_image_bytes,
        )

        if try_on_image_bytes_io:
            logger.info("Virtual try-on successful. Streaming image response.")
            return create_image_streaming_response(try_on_image_bytes_io)
        logger.warning("Virtual try-on returned no image. Sending JSON summary.")
        return JSONResponse(content=summary)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in /virtual-try-on: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator for this handler is visible here — confirm it
# is registered on `router`.
async def generate_shoe(request: ShoeGenerateRequest):
    """
    Generate shoe product images from a raw prompt and stream back the first one.

    Args:
        request: Body carrying the ``prompt`` text.

    Returns:
        A PNG ``StreamingResponse`` of the first generated image, or a 500
        ``JSONResponse`` ``{"success": False, "notes": ..., ["model_text": ...]}``
        when no image was produced.

    Raises:
        HTTPException: 400 when ``prompt`` is empty/missing; 500 on unexpected errors.
    """
    # Validate before using the prompt: if it can be None, slicing it for the log
    # line would raise TypeError and surface as an unhandled 500 instead of a 400.
    if not request.prompt:
        raise HTTPException(status_code=400, detail="`prompt` field is required.")
    logger.info("Received request to /generate-shoe: %s", request.prompt[:80])
    try:
        gen_text, images = generate_shoe_images(prompt=request.prompt)
        if images:
            # Only the first image is returned; rewind it before streaming.
            img_io: BytesIO = images[0]
            img_io.seek(0)
            return StreamingResponse(img_io, media_type="image/png")

        # No image: report failure, including any text the model produced instead.
        detail = {"success": False, "notes": "No image generated."}
        if gen_text:
            detail["model_text"] = gen_text
        return JSONResponse(status_code=500, content=detail)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in /generate-shoe: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
| from app.core.config import settings | |
| from groq import Groq | |
| import json | |
# Shared, synchronous Groq client used by the classifier endpoints in this module.
# It is constructed once at import time from `settings.LANGCHAIN_API_KEY`.
# NOTE(review): a setting named LANGCHAIN_API_KEY being used as the Groq API key
# looks like a possible misconfiguration — confirm it actually holds a Groq key.
client = Groq(api_key=settings.LANGCHAIN_API_KEY)
# --- Text classifier endpoints (shoe / dress) ---
# NOTE(review): no route decorators for `check_shoe` / `check_dress` are visible
# here — confirm they are registered on `router`.


async def _run_groq_classifier(text: str, schema_name: str, response_model, system_prompt: str):
    """
    Classify ``text`` with the Groq model and validate its JSON reply.

    Args:
        text: User-supplied text to classify.
        schema_name: Name sent with the JSON schema in ``response_format``.
        response_model: Pydantic model class; its JSON schema constrains the
            model's reply and it is used to validate that reply.
        system_prompt: System message describing the classification task.

    Returns:
        An instance of ``response_model`` parsed from the model's JSON output.
    """
    # Local import keeps this block self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": schema_name,
            "schema": response_model.model_json_schema(),
        },
    }
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    # `client` is the synchronous Groq client (its result is used without await),
    # so calling it directly inside `async def` would block the event loop for the
    # whole HTTP round-trip. Run it in FastAPI's worker thread pool instead.
    response = await run_in_threadpool(
        client.chat.completions.create,
        model="moonshotai/kimi-k2-instruct-0905",
        messages=messages,
        response_format=response_format,
    )
    return response_model.model_validate(json.loads(response.choices[0].message.content))


async def check_shoe(user_input: UserInput):
    """
    Check whether the user's text is about shoe images, using the Groq model.

    Returns the validated ``ShoeCheckResponse`` (documented as
    ``{"answer": "Yes"}`` or ``{"answer": "No"}``; the exact fields are defined
    by that schema).

    Raises:
        HTTPException: 500 with the error text if the call or parsing fails.
    """
    try:
        return await _run_groq_classifier(
            user_input.text,
            schema_name="shoe_check",
            response_model=ShoeCheckResponse,
            system_prompt="You are a strict classifier that determines if text is about shoe images.",
        )
    except Exception as e:
        logger.error("Error in check_shoe: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


# --- Dress Check Endpoint ---
async def check_dress(user_input: UserInput):
    """
    Check whether the user's text is about dress images, using the Groq model.

    Returns the validated ``DressCheckResponse`` (documented as
    ``{"answer": "Yes"}`` or ``{"answer": "No"}``; the exact fields are defined
    by that schema).

    Raises:
        HTTPException: 500 with the error text if the call or parsing fails.
    """
    try:
        return await _run_groq_classifier(
            user_input.text,
            schema_name="dress_check",
            response_model=DressCheckResponse,
            system_prompt="You are a strict classifier that determines if text is about dress images.",
        )
    except Exception as e:
        logger.error("Error in check_dress: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e