import io
import logging
import os

from fastapi import FastAPI, File, UploadFile
from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from gradio_client import Client
app = FastAPI()


# Hugging Face access token passed to the Gradio client; None when the
# HF_TOKEN environment variable is unset.
hf_token = os.environ.get('HF_TOKEN')

# Gradio Space that performs the image-to-sketch conversion. Overridable via
# SKETCH_SPACE_URL (e.g. to point at a duplicated Space); defaults to the
# original hard-coded Space so existing deployments are unaffected.
SKETCH_SPACE_URL = os.environ.get(
    "SKETCH_SPACE_URL", "https://ashrafb-image-to-sketch3.hf.space/"
)
# NOTE(review): the client is constructed at import time; if construction
# contacts the Space, app startup will fail while it is unreachable — confirm.
client = Client(SKETCH_SPACE_URL, hf_token=hf_token)
|
|
| import tempfile |
|
|
|
|
|
|
| import base64 |
| from fastapi.middleware.cors import CORSMiddleware |
|
|
# CORS: let browser clients on any origin call this API.
# Credentials are disabled because FastAPI documents that allow_origins=["*"]
# must not be combined with allow_credentials=True (Starlette would otherwise
# reflect any Origin while permitting credentialed requests). The visible
# /upload/ endpoint reads no cookies or auth headers, so nothing relies on them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Convert an uploaded image into a sketch using the remote Gradio Space.

    The upload is written to a temporary file (``client.predict`` is given a
    file path) and sent to the Space's ``/predict`` endpoint. The first item
    of the two-item result is then read back and returned base64-encoded.

    Args:
        file: The uploaded image.

    Returns:
        On success, ``{"sketch_image_base64": <data URI>, "result_file":
        <second result item, passed through unchanged>}``.
        On any failure, ``{"error": <message>}`` (still a normal, non-raising
        response, as before).
    """
    temp_file_path = None
    try:
        # Keep the upload's extension so anything that infers the format from
        # the filename sees e.g. ".png"/".jpg" instead of no suffix at all.
        suffix = os.path.splitext(file.filename or "")[1]
        # delete=False: the file must outlive the ``with`` so the client can
        # open it by path; it is removed in ``finally`` below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path before writing so a failed write is still cleaned up.
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        # client.predict is a blocking call; run it in a worker thread so it
        # does not stall the event loop for every other request.
        result = await run_in_threadpool(
            client.predict, temp_file_path, api_name="/predict"
        )

        # Expect exactly (sketch_image_path, result_file).
        if not isinstance(result, (list, tuple)) or len(result) != 2:
            return {"error": "Invalid result from the prediction API."}

        sketch_path, result_file = result
        with open(sketch_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode("utf-8")

        # NOTE(review): the data URI is always labelled PNG; this assumes the
        # Space returns a PNG image — confirm.
        return {
            "sketch_image_base64": f"data:image/png;base64,{image_data}",
            "result_file": result_file,
        }
    except Exception as e:
        # Request boundary: keep the endpoint's {"error": ...} contract, but
        # record the traceback instead of silently discarding it.
        logging.getLogger(__name__).exception("Image-to-sketch conversion failed")
        return {"error": str(e)}
    finally:
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
|
|
|
|
|
|