jree423 commited on
Commit
86771a8
·
verified ·
1 Parent(s): 8d52022

Update: Improve API implementation

Browse files
Files changed (1) hide show
  1. api.py +41 -43
api.py CHANGED
@@ -1,58 +1,56 @@
1
- from fastapi import FastAPI, HTTPException, Request
2
  from pydantic import BaseModel
3
- from typing import Dict, Any, Optional, List, Union
4
- import base64
5
- import io
6
- from PIL import Image
7
- import torch
8
  import os
 
9
  import sys
10
- import json
11
-
12
- # Import the handler
13
  from handler import EndpointHandler
14
 
15
- # Initialize the app
 
 
 
 
 
 
16
  app = FastAPI()
17
 
18
- # Initialize the model
19
- model = EndpointHandler(model_dir="/code")
 
 
 
 
 
 
 
 
 
20
 
21
  @app.post("/")
22
- async def process_request(request: Request):
23
  try:
24
- # Get the raw request body
25
- body = await request.body()
26
 
27
- # Try to parse as JSON
28
- try:
29
- data = json.loads(body)
30
- except:
31
- # If not JSON, treat as plain text
32
- data = {"inputs": body.decode("utf-8")}
33
 
34
- # Handle different input formats
35
- if isinstance(data, dict):
36
- if "inputs" in data:
37
- # Standard format
38
- pass
39
- elif "text" in data:
40
- # Text field directly
41
- data = {"inputs": data["text"]}
42
- else:
43
- # No recognized fields, use the whole dict as input
44
- data = {"inputs": str(data)}
45
  else:
46
- # Not a dict, use as is
47
- data = {"inputs": str(data)}
48
-
49
- # Process the request
50
- result = model(data)
51
- return result
52
  except Exception as e:
 
 
 
53
  raise HTTPException(status_code=500, detail=str(e))
54
-
55
- # Add a health check endpoint
56
- @app.get("/health")
57
- async def health():
58
- return {"status": "ok"}
 
1
+ from fastapi import FastAPI, Response, HTTPException
2
  from pydantic import BaseModel
3
+ from typing import Union, Dict, Any, Optional
 
 
 
 
4
  import os
5
+ import io
6
  import sys
 
 
 
7
  from handler import EndpointHandler
8
 
9
+ # Add debug logging
10
+ def debug_log(message):
11
+ print(f"DEBUG API: {message}")
12
+ sys.stdout.flush()
13
+
14
+ debug_log("Initializing API")
15
+
16
  app = FastAPI()
17
 
18
+ # Initialize the handler with the model directory
19
+ model_dir = os.environ.get("MODEL_DIR", "/code")
20
+ debug_log(f"Using model_dir: {model_dir}")
21
+ handler = EndpointHandler(model_dir)
22
+
23
+ class TextRequest(BaseModel):
24
+ inputs: Union[str, Dict[str, Any]]
25
+
26
+ @app.get("/")
27
+ def read_root():
28
+ return {"message": "Vector Graphics Generation API"}
29
 
30
  @app.post("/")
31
+ async def generate(request: TextRequest):
32
  try:
33
+ debug_log(f"Received request: {request}")
 
34
 
35
+ # Call the handler
36
+ result = handler(request.dict())
37
+ debug_log(f"Handler returned result of type: {type(result)}")
 
 
 
38
 
39
+ # If the result is a PIL Image, convert it to bytes
40
+ if hasattr(result, "save"):
41
+ debug_log("Converting PIL Image to bytes")
42
+ img_byte_arr = io.BytesIO()
43
+ result.save(img_byte_arr, format="PNG")
44
+ img_byte_arr.seek(0)
45
+
46
+ # Return the image as a response
47
+ return Response(content=img_byte_arr.getvalue(), media_type="image/png")
 
 
48
  else:
49
+ # Return the result as JSON
50
+ debug_log("Returning result as JSON")
51
+ return result
 
 
 
52
  except Exception as e:
53
+ debug_log(f"Error in generate endpoint: {e}")
54
+ import traceback
55
+ debug_log(traceback.format_exc())
56
  raise HTTPException(status_code=500, detail=str(e))