# Repository: Virtual-try-on / app/api/image_router.py
# Hammad712's picture
# Update app/api/image_router.py
# 7bb1cd2 verified
import logging
from io import BytesIO
from typing import Optional

from fastapi import APIRouter, HTTPException, Response, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, JSONResponse

from app.schemas import image as image_schemas
from app.schemas.image import ShoeGenerateRequest
from app.schemas.image import ShoeCheckResponse, UserInput, DressCheckResponse
from app.services import image_service
from app.services.image_service import generate_shoe_images
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router for every endpoint declared in this module; all routes are served
# under the /api/v1 prefix and grouped under the "Image Generation" tag.
router = APIRouter(prefix="/api/v1", tags=["Image Generation"])
# --- Helper function for streaming image responses ---
def create_image_streaming_response(image_data: BytesIO, media_type: str = "image/png"):
    """
    Wrap an in-memory image buffer in a StreamingResponse.

    Args:
        image_data: Buffer holding the encoded image.
        media_type: MIME type sent with the response (defaults to PNG).

    Returns:
        A StreamingResponse that streams the buffer from its start.

    Raises:
        HTTPException: 404 when ``image_data`` is falsy (e.g. ``None``).
    """
    if image_data:
        # Rewind first: the producer may have left the cursor at the end.
        image_data.seek(0)
        return StreamingResponse(image_data, media_type=media_type)
    raise HTTPException(status_code=404, detail="Image not found or could not be generated.")
# --- Endpoints ---
@router.post("/enhance-prompt", response_model=image_schemas.EnhancePromptResponse)
async def enhance_prompt(request: image_schemas.EnhancePromptRequest):
    """
    Enhance a raw user prompt through the image service.

    Args:
        request: Request body carrying ``raw_prompt``.

    Returns:
        An ``EnhancePromptResponse`` containing the original ``raw_prompt``
        and the ``enhanced_prompt`` produced by the service.

    Raises:
        HTTPException: 500 carrying the error text when the service call fails.
    """
    logger.info("Received request to /enhance-prompt for: %s...", request.raw_prompt[:50])
    try:
        # The service function is a plain (non-awaited) call, so run it in the
        # threadpool rather than blocking the event loop for its duration.
        enhanced_prompt = await run_in_threadpool(
            image_service.enhance_user_prompt, request.raw_prompt
        )
        logger.info("Successfully enhanced prompt.")
        return image_schemas.EnhancePromptResponse(
            raw_prompt=request.raw_prompt,
            enhanced_prompt=enhanced_prompt,
        )
    except Exception as e:
        logger.error("Error in /enhance-prompt: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/generate-image", response_class=StreamingResponse)  # Response class directly for image
async def generate_image(request: image_schemas.GenerateImageRequest):
    """
    Generate an image from either an enhanced or a raw prompt.

    ``enhanced_prompt`` takes precedence over ``raw_prompt`` when both are set.

    Args:
        request: Request body with ``enhanced_prompt`` and/or ``raw_prompt``.

    Returns:
        A StreamingResponse carrying the generated image.

    Raises:
        HTTPException: 400 when neither prompt is provided; 500 when the
            service returns no image (including any model text in the detail)
            or raises.
    """
    logger.info("Received request to /generate-image")

    # Prefer the enhanced prompt; fall back to the raw one.
    image_prompt = request.enhanced_prompt or request.raw_prompt
    if not image_prompt:
        logger.warning("Bad request to /generate-image: No prompt provided.")
        raise HTTPException(status_code=400, detail="Either raw_prompt or enhanced_prompt must be provided.")

    try:
        # Service returns (text, BytesIO). It is a plain synchronous call, so
        # run it in the threadpool to keep the event loop responsive.
        generated_text, image_bytes_io = await run_in_threadpool(
            image_service.generate_image_from_text, image_prompt
        )

        if not image_bytes_io:
            logger.error("Image generation failed or returned no image.")
            # Surface the model's text, if any, to explain the missing image.
            if generated_text:
                raise HTTPException(status_code=500, detail=f"Image generation failed. Model response: {generated_text}")
            raise HTTPException(status_code=500, detail="Image generation failed: No image data received.")

        logger.info("Successfully generated image. Streaming response.")
        return create_image_streaming_response(image_bytes_io)
    except HTTPException:  # Re-raise HTTPExceptions directly
        raise
    except Exception as e:
        logger.error("Error in /generate-image: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# --- update_image (MODIFIED) ---
@router.post("/update-image", response_class=StreamingResponse)  # Response class is already correct
async def update_image(
    image: UploadFile = File(..., description="The image to update (PNG, JPG)"),
    text_instruction: str = Form(..., description="The text instruction for what to change."),
):
    """
    Update an uploaded image according to a text instruction.

    Args:
        image: The uploaded image file.
        text_instruction: What to change in the image.

    Returns:
        A StreamingResponse carrying the updated image.

    Raises:
        HTTPException: 400 when the upload is not declared as an image;
            500 when the service returns no image or raises.
    """
    logger.info("Received request to /update-image")

    # content_type can be missing on an upload; treat that as "not an image"
    # instead of crashing on None.startswith(...).
    if not (image.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Image must be an image type (e.g., image/png, image/jpeg).")

    try:
        # Read the image bytes from the uploaded file
        image_bytes = await image.read()

        # Service returns (text, BytesIO). It is a plain synchronous call, so
        # run it in the threadpool to keep the event loop responsive.
        updated_text, updated_image_bytes_io = await run_in_threadpool(
            image_service.update_image_with_text,
            text_instruction=text_instruction,
            image_bytes=image_bytes,  # Pass the raw bytes to the service
        )

        if not updated_image_bytes_io:
            logger.error("Image update failed or returned no image.")
            if updated_text:
                raise HTTPException(status_code=500, detail=f"Image update failed. Model response: {updated_text}")
            raise HTTPException(status_code=500, detail="Image update failed: No image data received.")

        logger.info("Successfully updated image. Streaming response.")
        return create_image_streaming_response(updated_image_bytes_io)
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in /update-image: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# --- Corrected Virtual Try-On Endpoint with 3 Inputs ---
@router.post("/virtual-try-on")
async def virtual_try_on(
    dress_image: UploadFile = File(..., description="The dress image for try-on (PNG, JPG)"),
    person_image: UploadFile = File(..., description="The person image for try-on (PNG, JPG)"),
    shoes_image: Optional[UploadFile] = File(None, description="The shoes image for try-on (PNG, JPG)"),
) -> Response:
    """
    Perform a virtual try-on from a dress image, a person image and optional shoes.

    Args:
        dress_image: Uploaded dress image (required).
        person_image: Uploaded person image (required).
        shoes_image: Uploaded shoes image (optional).

    Returns:
        A StreamingResponse with the try-on image when the service produces
        one, otherwise a JSONResponse containing the service's summary.

    Raises:
        HTTPException: 400 when any provided upload is not declared as an
            image; 500 when the service raises.
    """
    logger.info("Received request to /virtual-try-on with image uploads.")

    # 1. Validate mandatory images.
    # content_type can be missing on an upload; `or ""` makes that a clean 400
    # rather than an AttributeError on None.
    if not (dress_image.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Dress image must be an image type.")
    if not (person_image.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Person image must be an image type.")

    # 2. Validate optional shoes (only when a file was actually supplied).
    if shoes_image and not (shoes_image.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Shoes image must be an image type.")

    try:
        # 3. Read bytes
        dress_image_bytes = await dress_image.read()
        person_image_bytes = await person_image.read()
        shoes_image_bytes = await shoes_image.read() if shoes_image else None

        # 4. Call the service. It is a plain synchronous call, so run it in
        # the threadpool to keep the event loop responsive.
        summary, try_on_image_bytes_io = await run_in_threadpool(
            image_service.virtual_try_on,
            dress_image_bytes=dress_image_bytes,
            person_image_bytes=person_image_bytes,
            shoes_image_bytes=shoes_image_bytes,
        )

        # 5. Return the image when there is one, otherwise the JSON summary.
        if try_on_image_bytes_io:
            logger.info("Virtual try-on successful. Streaming image response.")
            return create_image_streaming_response(try_on_image_bytes_io)

        logger.warning("Virtual try-on returned no image. Sending JSON summary.")
        return JSONResponse(content=summary)
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in /virtual-try-on: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/generate-shoe", response_class=StreamingResponse)
async def generate_shoe(request: ShoeGenerateRequest):
    """
    Generate shoe product images from a raw prompt and stream the first one.

    Args:
        request: Request body carrying ``prompt``.

    Returns:
        A StreamingResponse (PNG) with the first generated image, or a
        JSONResponse with status 500 and failure notes (plus any model text)
        when no image was produced.

    Raises:
        HTTPException: 400 when ``prompt`` is empty or missing; 500 when the
            service raises.
    """
    # `or ""` keeps the log line safe when the prompt is missing, so the
    # validation below can answer with a 400 instead of a TypeError.
    logger.info("Received request to /generate-shoe: %s", (request.prompt or "")[:80])
    if not request.prompt:
        raise HTTPException(status_code=400, detail="`prompt` field is required.")

    try:
        # The service is a plain synchronous call; run it in the threadpool
        # so the event loop is not blocked while images are generated.
        gen_text, images = await run_in_threadpool(generate_shoe_images, prompt=request.prompt)

        if images:
            # Stream the first image
            img_io: BytesIO = images[0]
            img_io.seek(0)
            return StreamingResponse(img_io, media_type="image/png")

        # Handle failure / text-only response
        detail = {"success": False, "notes": "No image generated."}
        if gen_text:
            detail["model_text"] = gen_text
        return JSONResponse(status_code=500, content=detail)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in /generate-shoe: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): these imports sit in the middle of the module; by convention
# they belong with the other imports at the top of the file.
from app.core.config import settings
from groq import Groq
import json
# Module-level Groq client used by the /check-shoe and /check-dress endpoints.
# NOTE(review): the client is built with settings.LANGCHAIN_API_KEY — confirm
# that this setting really holds a Groq API key.
client = Groq(api_key=settings.LANGCHAIN_API_KEY)
# --- Endpoint ---
@router.post("/check-shoe", response_model=ShoeCheckResponse)
async def check_shoe(user_input: UserInput):
    """
    Classify whether the user's text is about shoe images using a Groq model.

    The model is asked for JSON matching the ``ShoeCheckResponse`` schema
    (e.g. {"answer": "Yes"} or {"answer": "No"}).

    Args:
        user_input: Request body carrying the ``text`` to classify.

    Returns:
        A ``ShoeCheckResponse`` validated from the model's JSON output.

    Raises:
        HTTPException: 500 carrying the error text when the model call or the
            parsing/validation of its output fails.
    """
    try:
        # Define the JSON schema for response
        response_format = {
            "type": "json_schema",
            "json_schema": {
                "name": "shoe_check",
                "schema": ShoeCheckResponse.model_json_schema(),
            },
        }

        # System + user prompt
        messages = [
            {"role": "system", "content": "You are a strict classifier that determines if text is about shoe images."},
            {"role": "user", "content": user_input.text},
        ]

        # The Groq client call is synchronous; run it in the threadpool so
        # the event loop is not blocked while waiting on the model.
        response = await run_in_threadpool(
            client.chat.completions.create,
            model="moonshotai/kimi-k2-instruct-0905",
            messages=messages,
            response_format=response_format,
        )

        # Parse model output
        return ShoeCheckResponse.model_validate(json.loads(response.choices[0].message.content))
    except Exception as e:
        # Log like the other endpoints do, so failures are not silent.
        logger.error("Error in /check-shoe: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# --- Dress Check Endpoint ---
@router.post("/check-dress", response_model=DressCheckResponse)
async def check_dress(user_input: UserInput):
    """
    Classify whether the user's text is about dress images using a Groq model.

    The model is asked for JSON matching the ``DressCheckResponse`` schema
    (e.g. {"answer": "Yes"} or {"answer": "No"}).

    Args:
        user_input: Request body carrying the ``text`` to classify.

    Returns:
        A ``DressCheckResponse`` validated from the model's JSON output.

    Raises:
        HTTPException: 500 carrying the error text when the model call or the
            parsing/validation of its output fails.
    """
    try:
        # JSON schema the model's response must conform to
        response_format = {
            "type": "json_schema",
            "json_schema": {
                "name": "dress_check",
                "schema": DressCheckResponse.model_json_schema(),
            },
        }

        # System + user prompt
        messages = [
            {"role": "system", "content": "You are a strict classifier that determines if text is about dress images."},
            {"role": "user", "content": user_input.text},
        ]

        # The Groq client call is synchronous; run it in the threadpool so
        # the event loop is not blocked while waiting on the model.
        response = await run_in_threadpool(
            client.chat.completions.create,
            model="moonshotai/kimi-k2-instruct-0905",
            messages=messages,
            response_format=response_format,
        )

        # Parse model output
        return DressCheckResponse.model_validate(json.loads(response.choices[0].message.content))
    except Exception as e:
        # Log like the other endpoints do, so failures are not silent.
        logger.error("Error in /check-dress: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e