# Itsdockertest / main.py
# Author: Ashrafb
# Commit 5e78887 (verified): "Update main.py"
import asyncio
import base64
import logging
import mimetypes
import os
import tempfile

from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from gradio_client import Client, handle_file
app = FastAPI()

# Gradio Space that performs the image-to-sketch conversion.
SKETCH_SPACE_ID = "Ashrafb/image-to-sketch"

# Optional Hugging Face token taken from the environment (None when HF_TOKEN
# is not set); it is passed straight through to the Gradio client.
hf_token = os.getenv("HF_TOKEN")

# Shared Gradio client used by the /upload/ route to call the Space.
client = Client(SKETCH_SPACE_ID, hf_token=hf_token)
# Configure CORS middleware.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable; when it is unset or empty, any origin ("*") is allowed as before.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentialed CORS must not be combined with a wildcard origin (FastAPI's
    # CORS docs require explicit origins when allow_credentials=True), so
    # credentials are only enabled when specific origins were configured.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Turn an uploaded image into a sketch using the remote Gradio Space.

    The upload is saved to a temporary file (keeping the client's file
    extension, if any) and sent to the Space's ``/predict`` endpoint through
    the module-level ``client``. The first file in the prediction result is
    returned inline as a base64 ``data:`` URL.

    Args:
        file: The image uploaded as multipart form data.

    Returns:
        On success: ``{"sketch_image_base64": "data:<mime>;base64,...",
        "result_file": <second element of the prediction result>}``.
        On failure: ``{"error": <message>}``. Errors are reported in the body
        with the default 200 status, as existing clients expect.
    """
    temp_file_path = None
    try:
        # Keep the original extension so the saved upload retains a
        # recognizable file type when handed to the Space.
        suffix = os.path.splitext(file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path before writing so the finally-block can always
            # remove the file, even if reading or writing the upload fails.
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        # Client.predict blocks until the remote job finishes; run it in a
        # worker thread so it does not stall the event loop for other requests.
        result = await asyncio.to_thread(
            client.predict,
            img=handle_file(temp_file_path),
            api_name="/predict",
        )

        # Expect exactly two outputs: the sketch image path and a result file.
        if not (isinstance(result, (list, tuple)) and len(result) == 2):
            return {"error": "Invalid result from the prediction API."}

        sketch_path, result_file = result
        with open(sketch_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode("utf-8")

        # Label the data URL with the real image type when it can be guessed
        # from the file name; otherwise fall back to PNG as before.
        mime_type, _ = mimetypes.guess_type(sketch_path)
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/png"

        return {
            "sketch_image_base64": f"data:{mime_type};base64,{image_data}",
            "result_file": result_file,
        }
    except Exception as e:
        # Request boundary: report the failure to the caller, but also log the
        # traceback so it is not silently lost.
        logging.getLogger(__name__).exception("Sketch conversion failed")
        return {"error": str(e)}
    finally:
        # Clean up the temporary upload.
        if temp_file_path and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)