# diffsketcher / api.py
# Author: jree423
# Last change: "Update: Improve API implementation" (commit 86771a8, verified)
# Size: 1.72 kB
import io
import os
import sys
import threading
from typing import Union, Dict, Any, Optional

from fastapi import FastAPI, Response, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

from handler import EndpointHandler
def debug_log(message):
    """Write ``message`` to stdout with a ``DEBUG API:`` prefix, then flush.

    The explicit flush makes each message appear promptly even when stdout is
    buffered (for example when it is redirected to a pipe or log file).
    """
    stream = sys.stdout
    print(f"DEBUG API: {message}", file=stream)
    stream.flush()
debug_log("Initializing API")
app = FastAPI()

# One handler instance is built at import time and shared by all requests.
# The directory it loads from defaults to "/code" and can be overridden with
# the MODEL_DIR environment variable.
model_dir = os.getenv("MODEL_DIR", "/code")
debug_log(f"Using model_dir: {model_dir}")
handler = EndpointHandler(model_dir)
class TextRequest(BaseModel):
    # JSON body accepted by the POST / endpoint.
    # `inputs` is either a plain string or a JSON object.
    # NOTE(review): the keys expected inside the object form are defined by
    # handler.EndpointHandler, not here — confirm there.
    # NOTE: pydantic may try Union members in order; do not reorder str/Dict
    # without checking how string inputs are validated.
    inputs: Union[str, Dict[str, Any]]
@app.get("/")
def read_root():
    # GET / returns a fixed JSON description of the service.
    # (A comment rather than a docstring keeps the OpenAPI description unchanged.)
    banner = "Vector Graphics Generation API"
    return {"message": banner}
# The model handler is shared by every request. Calls into it are serialized
# with this lock, so running the handler on a worker thread (to keep the event
# loop responsive) does not let concurrent requests enter it in parallel.
# NOTE(review): drop the lock only after confirming EndpointHandler is thread-safe.
_HANDLER_LOCK = threading.Lock()


def _call_handler(payload):
    """Run the shared ``handler`` on ``payload`` while holding ``_HANDLER_LOCK``."""
    with _HANDLER_LOCK:
        return handler(payload)


@app.post("/")
async def generate(request: TextRequest):
    """Run the model handler on the request body and return its result.

    If the handler returns an object exposing a PIL-style ``save`` method, it is
    encoded as PNG and returned with media type ``image/png``; any other result
    is returned unchanged for FastAPI to serialize as JSON.

    Raises:
        HTTPException: re-raised unchanged if one is raised while handling the
            request; any other exception is logged with its traceback and
            reported as a 500 whose detail is the exception message.
    """
    try:
        debug_log(f"Received request: {request}")
        # pydantic v2 renamed BaseModel.dict() to model_dump() (dict() is
        # deprecated there); prefer the new name but keep working on v1.
        if hasattr(request, "model_dump"):
            payload = request.model_dump()
        else:
            payload = request.dict()
        # The handler is a plain synchronous call; running it in FastAPI's
        # threadpool keeps the event loop free to serve other requests meanwhile.
        result = await run_in_threadpool(_call_handler, payload)
        debug_log(f"Handler returned result of type: {type(result)}")
        if hasattr(result, "save"):
            # Duck-typed check: presumably a PIL Image — encode it as PNG bytes.
            debug_log("Converting PIL Image to bytes")
            buffer = io.BytesIO()
            result.save(buffer, format="PNG")
            # getvalue() returns the whole buffer regardless of the stream
            # position, so no seek(0) is needed.
            return Response(content=buffer.getvalue(), media_type="image/png")
        debug_log("Returning result as JSON")
        return result
    except HTTPException:
        # Preserve deliberate HTTP errors (status code and detail) instead of
        # flattening them into a generic 500.
        raise
    except Exception as e:
        import traceback

        debug_log(f"Error in generate endpoint: {e}")
        debug_log(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e