File size: 1,926 Bytes
cc0b0fb
 
f787272
2c151eb
cc0b0fb
 
2c151eb
cc0b0fb
 
 
e72c238
05761be
cc0b0fb
2c151eb
b84f01e
05761be
cc0b0fb
c1e41b5
 
 
 
 
 
 
215cdaf
29da619
 
cc0b0fb
2371a34
 
 
 
 
cc0b0fb
0151799
cc0b0fb
0151799
 
b396701
ae9a263
bff39e2
cc0b0fb
ae9a263
 
 
bff39e2
 
 
 
ae9a263
bff39e2
215cdaf
 
2371a34
cc0b0fb
215cdaf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import base64
import logging
import mimetypes
import os
import tempfile

from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from gradio_client import Client, handle_file

# FastAPI application instance; the /upload/ route below is registered on it.
app = FastAPI()

# Retrieve Hugging Face token from environment variables.
# os.environ.get yields None when HF_TOKEN is unset, in which case the client
# below is created with hf_token=None (presumably fine for a public Space — confirm).
hf_token = os.environ.get('HF_TOKEN')
# Gradio client for the "Ashrafb/image-to-sketch" Space, created once at module
# import and shared by every request handler in this module.
# NOTE(review): if Client(...) contacts the Space on construction, an unreachable
# Space would make this module fail to import — confirm eager setup is intended.
client = Client("Ashrafb/image-to-sketch", hf_token=hf_token)

# Configure CORS middleware
# NOTE(review): FastAPI's CORS docs state that allow_origins / allow_methods /
# allow_headers should not be ["*"] when allow_credentials=True. Consider listing
# explicit origins, or dropping credentials if /upload/ does not rely on cookies.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust as needed, '*' allows requests from any origin
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level logger so prediction failures are recorded server-side rather
# than only being echoed back to the HTTP client.
logger = logging.getLogger(__name__)


def _guess_image_mime(path):
    """Return an ``image/*`` MIME type for *path*, defaulting to ``image/png``.

    The type is guessed from the file extension. Anything that does not map to
    an ``image/*`` type falls back to ``image/png``, the type this endpoint
    previously hard-coded.
    """
    mime_type, _ = mimetypes.guess_type(str(path))
    if mime_type and mime_type.startswith("image/"):
        return mime_type
    return "image/png"


def _predict_sketch(image_path):
    """Send *image_path* to the Gradio sketch model and build the JSON payload.

    Blocking: calls ``client.predict`` and reads the returned file from disk,
    so callers in async code must run it off the event loop.

    Returns:
        ``{"sketch_image_base64": <data URL>, "result_file": <second output>}``
        when the prediction returns exactly two outputs, otherwise
        ``{"error": "Invalid result from the prediction API."}``.

    Raises:
        Whatever ``client.predict`` or opening/reading the sketch file raises.
    """
    result = client.predict(
        img=handle_file(image_path),
        api_name="/predict"
    )

    # Two outputs are expected: a path to the sketch image first, then a
    # second value that is forwarded to the caller unchanged.
    if not result or len(result) != 2:
        return {"error": "Invalid result from the prediction API."}
    sketch_path, result_file = result[0], result[1]

    with open(sketch_path, "rb") as image_file:
        image_data = base64.b64encode(image_file.read()).decode("utf-8")

    return {
        # Label the data URL with the sketch's actual type instead of assuming PNG.
        "sketch_image_base64": f"data:{_guess_image_mime(sketch_path)};base64,{image_data}",
        "result_file": result_file,
    }


@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Convert an uploaded image into a sketch using the Gradio Space.

    The upload is written to a temporary file, which keeps the original
    filename's extension when that extension is alphanumeric. The file is
    passed to the model, and it is always removed afterwards.

    Returns:
        The payload built by ``_predict_sketch`` on success. If the prediction
        step raises, returns ``{"error": <message>}`` instead.

    Raises:
        Errors while reading the upload or creating the temp file propagate
        unchanged (as before).
    """
    # Preserve the client's extension so the file handed to the Space has a
    # recognisable type; reject anything non-alphanumeric (empty suffix is the
    # previous behaviour).
    suffix = os.path.splitext(file.filename or "")[1]
    if not suffix[1:].isalnum():
        suffix = ""

    temp_file_path = None
    try:
        # delete=False: the file must outlive this `with` so it can be read by
        # path; it is removed in the `finally` below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path before writing so a failed write is still cleaned up.
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        try:
            # client.predict and the sketch-file read are blocking; run them in
            # a worker thread so the event loop keeps serving other requests.
            return await run_in_threadpool(_predict_sketch, temp_file_path)
        except Exception as e:
            logger.exception("Sketch prediction failed")
            return {"error": str(e)}
    finally:
        # Clean up the temporary file
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)