from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import asyncio
import tempfile
import os
from threading import RLock
from huggingface_hub import InferenceClient
from PIL import Image # Import Pillow
from io import BytesIO # For converting image to bytes
myapp = Flask(__name__)
CORS(myapp) # Enable CORS for all routes
lock = RLock()
HF_TOKEN = os.environ.get("HF_TOKEN") # Hugging Face token
inference_timeout = 600 # Set timeout for inference
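# NOTE: the original file does not define the "models" whitelist that
# get_model_from_name() checks against; this is an assumed placeholder
# containing only the default model so the lookup below works.
# Extend it with the models you actually want to serve.
models = [
    "prompthero/openjourney-v4",
]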
@myapp.route('/')
def home():
    return "Welcome to the Image Generator API!"
# Validate a requested model name against the "models" whitelist
def get_model_from_name(model_name):
    return model_name if model_name in models else None
# Asynchronously run text-to-image inference in a worker thread, with a timeout
async def infer(client, prompt, seed=1, timeout=inference_timeout, model="prompthero/openjourney-v4"):
    task = asyncio.create_task(
        asyncio.to_thread(client.text_to_image, prompt=prompt, seed=seed, model=model)
    )
    await asyncio.sleep(0)
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except Exception as e:
        print(e)
        print(f"Inference failed or timed out for model: {model}")
        if not task.done():
            task.cancel()
        result = None
    if task.done() and result is not None:
        with lock:
            # client.text_to_image returns a PIL image; serialize it to PNG bytes
            image_bytes = BytesIO()
            result.save(image_bytes, format='PNG')
            image_bytes.seek(0)  # Rewind to the start of the byte stream
            # Persist the PNG to a temporary file so Flask can send it back
            temp_image = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
            with open(temp_image.name, "wb") as f:
                f.write(image_bytes.read())
            return temp_image.name  # Path to the saved image
    return None
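# Standalone usage of infer() (a minimal sketch, assuming HF_TOKEN is set in the
# environment; the prompt text is illustrative only):
#
#   client = InferenceClient(token=HF_TOKEN)
#   path = asyncio.run(infer(client, "an astronaut riding a horse", seed=7))
#   print(path)  # Path to a temporary PNG file, or None on failure/timeout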
# Flask route for the image-generation endpoint
@myapp.route('/generate_image', methods=['POST'])
def generate_image():
    data = request.get_json()
    # Extract the required fields from the request body
    prompt = data.get('prompt', '')
    seed = data.get('seed', 1)
    model_name = data.get('model', 'prompthero/openjourney-v4')  # Default model
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    # Validate the requested model against the "models" whitelist
    model = get_model_from_name(model_name)
    if not model:
        return jsonify({"error": f"Model '{model_name}' not found in available models"}), 400
    try:
        # Create a generic InferenceClient; the model is selected per call in infer()
        client = InferenceClient(token=HF_TOKEN)
        # Run the async inference function to completion
        result_path = asyncio.run(infer(client, prompt, seed, model=model))
        if result_path:
            return send_file(result_path, mimetype='image/png')  # Send back the generated image file
        else:
            return jsonify({"error": "Failed to generate image"}), 500
    except Exception as e:
        print(f"Error in generate_image: {str(e)}")  # Log the error
        return jsonify({"error": str(e)}), 500
# Run the Flask development server directly when executed as a script
if __name__ == "__main__":
    myapp.run(host='0.0.0.0', port=7860)  # Run directly if needed for testing
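# Example client call (a minimal sketch; assumes the server above is running
# locally on port 7860 and that the "requests" package is installed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/generate_image",
#       json={"prompt": "a watercolor fox", "seed": 42, "model": "prompthero/openjourney-v4"},
#   )
#   resp.raise_for_status()
#   with open("generated.png", "wb") as f:
#       f.write(resp.content)  # The endpoint returns the PNG bytes directly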