File size: 1,930 Bytes
a5cec1c
 
f787272
2c151eb
a5cec1c
47bac16
2c151eb
6057ba6
a5cec1c
 
e72c238
05761be
a5cec1c
2c151eb
a5cec1c
53afd1c
a5cec1c
c1e41b5
 
 
 
 
 
 
215cdaf
29da619
 
53849fa
 
 
 
 
2371a34
53849fa
1bf6c45
 
 
 
b396701
53849fa
bff39e2
53849fa
ae9a263
 
 
bff39e2
 
 
 
ae9a263
bff39e2
215cdaf
 
2371a34
53849fa
215cdaf
53849fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import base64
import logging
import os
import tempfile
from contextlib import suppress

from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from gradio_client import Client, handle_file

app = FastAPI()

# Hugging Face access token, read from the environment (None when unset) so
# it never has to live in source control.
hf_token = os.getenv("HF_TOKEN")

# Client for the remote "Ashrafb/image-to-sketch-ari" Gradio Space. It is
# built once, when this module is imported, and reused by the request handlers.
client = Client("Ashrafb/image-to-sketch-ari", hf_token=hf_token)

# Cross-origin policy: every origin, method and header is accepted, and
# credentials are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended and narrow the origin list for
# production deployments.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Module logger so failed conversions keep their traceback in the server logs.
logger = logging.getLogger(__name__)


def _predict_sketch(image_path):
    """Send the image at *image_path* to the Gradio Space and build the response.

    This call blocks (remote prediction plus a local file read), so async
    callers should run it off the event loop.

    Returns:
        ``{"sketch_image_base64": <data URL>, "result_file": <second output>}``
        on success, or ``{"error": ...}`` when the prediction result is empty
        or does not have exactly two elements.

    Raises:
        Whatever ``client.predict`` or opening/reading ``result[0]`` raises.
    """
    result = client.predict(
        img=handle_file(image_path),
        api_name="/predict",
    )
    # presumably (sketch_image_path, result_file) — TODO confirm against the
    # Space's API definition.
    if not result or len(result) != 2:
        return {"error": "Invalid result from the prediction API."}

    with open(result[0], "rb") as image_file:
        image_data = base64.b64encode(image_file.read()).decode("utf-8")

    # NOTE(review): the MIME type is hard-coded to PNG whatever format the
    # Space returns; kept as-is because clients may rely on this exact prefix.
    return {
        "sketch_image_base64": f"data:image/png;base64,{image_data}",
        "result_file": result[1],
    }


@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Convert an uploaded image into a sketch using the remote Gradio Space.

    The upload is written to a temporary file, whose path is handed to
    ``handle_file`` for the Space's ``/predict`` endpoint. The temporary file
    is always removed afterwards.

    Args:
        file: The uploaded image.

    Returns:
        A dict with the sketch as a base64 data URL (``sketch_image_base64``)
        and the Space's second output (``result_file``). On any failure it
        returns ``{"error": message}`` instead. Errors are reported in the
        response body rather than via an HTTP error status, matching the
        endpoint's existing contract.
    """
    temp_file_path = None
    try:
        # delete=False: the file must outlive this ``with`` so it can be
        # reopened by path; the ``finally`` below removes it.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            # Record the path before anything can fail, so the file is still
            # cleaned up if reading the upload or writing it raises.
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        # client.predict is synchronous; running it in the threadpool keeps
        # the event loop free to serve other requests while it waits.
        return await run_in_threadpool(_predict_sketch, temp_file_path)
    except Exception as e:
        # API boundary: keep the full traceback in the logs, return only the
        # message to the caller.
        logger.exception("Sketch conversion failed")
        return {"error": str(e)}
    finally:
        if temp_file_path is not None:
            # Best-effort cleanup: the file may already be gone.
            with suppress(FileNotFoundError):
                os.unlink(temp_file_path)