Spaces:
Paused
Paused
| from fastapi import FastAPI, File, UploadFile, HTTPException, Request | |
| from fastapi.responses import HTMLResponse, FileResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from gradio_client import Client, handle_file | |
| import os | |
| import tempfile | |
| import base64 | |
app = FastAPI()

# Optional Hugging Face access token; os.getenv() yields None when HF_TOKEN
# is not set in the environment.
hf_token = os.getenv("HF_TOKEN")

# Shared Gradio client for the remote "Ashrafb/image-to-sketch" Space. It is
# constructed once, when this module is imported, and reused by the handlers.
client = Client(
    "Ashrafb/image-to-sketch",
    hf_token=hf_token,
)
# Cross-origin (CORS) policy: requests are accepted from any origin, with any
# HTTP method and any request header, and credentials are allowed.
# NOTE(review): FastAPI's CORS docs state that allow_origins/methods/headers
# should not be ["*"] when allow_credentials=True. Nothing visible here relies
# on cookies or auth, so consider listing explicit origins or turning
# credentials off. Left unchanged to avoid breaking existing cross-origin callers.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
async def upload_file(file: UploadFile = File(...)):
    """Turn an uploaded image into a sketch using the remote Gradio Space.

    The upload is written to a temporary file and sent to the Space's
    ``/predict`` endpoint through the module-level ``client``. The returned
    sketch image is then inlined as a base64 ``data:`` URI.

    NOTE(review): no route decorator is visible for this handler. Confirm it
    is registered with ``app`` (e.g. ``@app.post(...)``); otherwise FastAPI
    never calls it.

    Args:
        file: The uploaded image (multipart form field).

    Returns:
        On success, ``{"sketch_image_base64": <data URI>, "result_file":
        <second prediction output>}``. On any failure, ``{"error": <message>}``
        is returned (not raised as an HTTP error).
    """
    # FastAPI's helper for running blocking code off the event loop.
    from fastapi.concurrency import run_in_threadpool

    temp_file_path = None
    try:
        contents = await file.read()
        # delete=False keeps the file after the `with` block so handle_file()
        # can read it. It is removed explicitly in `finally`.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            # Record the path before writing so a failed write cannot leak the file.
            temp_file_path = temp_file.name
            temp_file.write(contents)
        # client.predict() is a blocking network round-trip. Calling it directly
        # in this coroutine would stall every other request on the event loop,
        # so it runs (with the result-file read) in a worker thread.
        return await run_in_threadpool(_sketch_response_from_path, temp_file_path)
    except Exception as e:
        # Deliberate catch-all: every failure is reported as {"error": ...}.
        return {"error": str(e)}
    finally:
        # Clean up the temporary upload, if one was created.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)


def _sketch_response_from_path(image_path):
    """Run the remote sketch model on *image_path* and build the response dict (blocking)."""
    import mimetypes

    result = client.predict(
        img=handle_file(image_path),
        api_name="/predict",
    )
    # Expect exactly two outputs: the sketch image file path and the result file.
    if not (isinstance(result, (list, tuple)) and len(result) == 2):
        return {"error": "Invalid result from the prediction API."}
    sketch_path, result_file = result
    with open(sketch_path, "rb") as image_file:
        image_data = base64.b64encode(image_file.read()).decode("utf-8")
    # Label the data URI with the file's actual type. Fall back to PNG when
    # the extension is not recognized.
    mime_type = mimetypes.guess_type(sketch_path)[0] or "image/png"
    return {
        "sketch_image_base64": f"data:{mime_type};base64,{image_data}",
        "result_file": result_file,
    }