Spaces:
Paused
Paused
File size: 1,926 Bytes
cc0b0fb f787272 2c151eb cc0b0fb 2c151eb cc0b0fb e72c238 05761be cc0b0fb 2c151eb b84f01e 05761be cc0b0fb c1e41b5 215cdaf 29da619 cc0b0fb 2371a34 cc0b0fb 0151799 cc0b0fb 0151799 b396701 ae9a263 bff39e2 cc0b0fb ae9a263 bff39e2 ae9a263 bff39e2 215cdaf 2371a34 cc0b0fb 215cdaf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 | from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from gradio_client import Client, handle_file
import os
import tempfile
import base64
# ASGI application; the route and middleware below attach to this instance.
app = FastAPI()

# Hugging Face access token taken from the environment (None when unset).
hf_token = os.getenv("HF_TOKEN")

# Gradio client for the "Ashrafb/image-to-sketch" Space. It is created at
# import time, so whatever the Client constructor does (e.g. contacting the
# Space) happens when this module is loaded.
client = Client("Ashrafb/image-to-sketch", hf_token=hf_token)

# Cross-origin policy: every origin, method and header is accepted.
# NOTE(review): a wildcard origin together with allow_credentials=True makes
# Starlette reflect the caller's Origin for credentialed requests, i.e. any
# site may send them. Nothing visible here relies on cookies or auth headers;
# consider allow_credentials=False or an explicit list of front-end origins.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
def _safe_suffix(filename):
    """Return the extension of *filename* if it is short plain ASCII alphanumerics, else "".

    The filename is supplied by the client, so only a conservative extension
    (e.g. ".png") is reused when naming the temporary file.
    """
    suffix = os.path.splitext(filename or "")[1]
    stem = suffix[1:]
    if stem and len(stem) <= 10 and stem.isascii() and stem.isalnum():
        return suffix
    return ""


def _run_sketch(image_path):
    """Send *image_path* to the Space and build the ``/upload/`` response dict.

    This does blocking network and disk I/O, so call it from a worker thread
    rather than directly on the event loop.

    Returns:
        ``{"sketch_image_base64": <data URL>, "result_file": result[1]}`` when
        the prediction result has exactly two elements, otherwise
        ``{"error": "Invalid result from the prediction API."}``.

    Raises:
        Whatever ``client.predict``, ``handle_file`` or reading the result
        file raise; the caller converts these into an error response.
    """
    import mimetypes

    result = client.predict(
        img=handle_file(image_path),
        api_name="/predict",
    )
    if not (result and len(result) == 2):
        return {"error": "Invalid result from the prediction API."}

    sketch_path = result[0]
    with open(sketch_path, "rb") as image_file:
        image_data = base64.b64encode(image_file.read()).decode("utf-8")

    # Label the data URL with the file's real image type; fall back to PNG
    # (the previously hard-coded value) when it cannot be determined.
    mime_type = mimetypes.guess_type(sketch_path)[0]
    if not (mime_type and mime_type.startswith("image/")):
        mime_type = "image/png"

    return {
        "sketch_image_base64": f"data:{mime_type};base64,{image_data}",
        "result_file": result[1],
    }


@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Convert an uploaded image into a sketch via the remote Gradio Space.

    The upload is written to a temporary file, sent to the Space's
    ``/predict`` endpoint, and the first returned file is inlined as a
    base64 data URL. The temporary file is always removed afterwards.

    Args:
        file: The image uploaded as multipart form data.

    Returns:
        On success, a dict with ``sketch_image_base64`` (a ``data:`` URL) and
        ``result_file`` (the second prediction result, passed through as-is).
        If the prediction step fails, a dict with a single ``error`` message.
        In that case the HTTP status is still 200, as before.
    """
    from fastapi.concurrency import run_in_threadpool

    temp_file_path = None
    try:
        # Keep a sanitised copy of the upload's extension; an extension-less
        # temp file leaves the remote side unable to infer the image type.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=_safe_suffix(file.filename)
        ) as temp_file:
            # Record the path before writing so the finally clause also
            # removes the file when reading or writing the upload fails.
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        try:
            # client.predict is synchronous and waits on the network; run it
            # in the thread pool so other requests are not blocked meanwhile.
            return await run_in_threadpool(_run_sketch, temp_file_path)
        except Exception as e:
            return {"error": str(e)}
    finally:
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except FileNotFoundError:
                pass  # already gone; nothing left to clean up
|