File size: 3,417 Bytes
3af9bb2
8be12ab
33fcfe8
 
3af9bb2
33fcfe8
3af9bb2
33fcfe8
21441f0
8be12ab
ec20299
 
3af9bb2
33fcfe8
 
 
 
3af9bb2
a40654e
 
 
 
33fcfe8
 
 
a40654e
33fcfe8
 
 
 
 
 
3af9bb2
33fcfe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3af9bb2
 
ec20299
33fcfe8
3af9bb2
3806071
3af9bb2
 
 
3806071
3af9bb2
 
 
 
33fcfe8
 
 
 
 
3af9bb2
33fcfe8
 
 
 
 
 
 
3af9bb2
 
 
33fcfe8
3af9bb2
1ff4547
f9c8cf7
8be12ab
ec20299
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import asyncio
import tempfile
import os
from threading import RLock
from huggingface_hub import InferenceClient
from PIL import Image  # Import Pillow
from io import BytesIO  # For converting image to bytes

myapp = Flask(__name__)
CORS(myapp)  # allow cross-origin requests on every route

# Shared re-entrant lock; infer() holds it while writing result images.
lock = RLock()
# Hugging Face API token taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Seconds to wait for a single inference call before giving up.
inference_timeout = 600

@myapp.route('/')
def home():
    """Serve the root URL with a short plain-text greeting.

    NOTE(review): the greeting mentions a background remover, while this
    service exposes text-to-image generation (``/generate_image``); confirm
    the intended wording.
    """
    message = "Welcome to the Image Background Remover!"
    return message

# Look up a requested model id in the allowed "models" list
def get_model_from_name(model_name, available_models=None):
    """Return *model_name* if it is an allowed model id, otherwise ``None``.

    Args:
        model_name: Model id requested by the client.
        available_models: Container of allowed model ids. When omitted, the
            module-level ``models`` list is used, looked up at call time.

    Returns:
        ``model_name`` when it is contained in the allowed ids, else ``None``.

    NOTE(review): no ``models`` list is defined or imported in this file (the
    caller's comment mentions ``all_models``). When it is absent, every name
    is rejected (``None``) instead of raising ``NameError``; provide or
    import ``models`` to enable generation.
    """
    if available_models is None:
        # Resolved lazily so a list assigned/imported after definition is seen.
        available_models = globals().get("models", ())
    return model_name if model_name in available_models else None

# Asynchronous function to perform inference
async def infer(client, prompt, seed=1, timeout=inference_timeout, model="prompthero/openjourney-v4"):
    """Generate an image for *prompt* with *model* and save it as a PNG temp file.

    The blocking ``client.text_to_image`` call runs on a worker thread so that
    ``timeout`` can be enforced with ``asyncio.wait_for``.

    Args:
        client: Object exposing ``text_to_image(prompt=..., seed=..., model=...)``
            (an ``InferenceClient`` at this file's call site). Its return value
            must support ``.save(fp, format='PNG')``.
        prompt: Text prompt forwarded to ``text_to_image``.
        seed: Seed forwarded to ``text_to_image``.
        timeout: Maximum seconds to wait for the inference call.
        model: Model id forwarded to ``text_to_image``.

    Returns:
        Path of a newly created ``.png`` file (created with ``delete=False``,
        so the caller owns it and should remove it), or ``None`` if the call
        failed, timed out, or returned nothing.
    """
    import concurrent.futures
    import functools

    loop = asyncio.get_running_loop()
    # Use a private executor rather than asyncio.to_thread's default one:
    # asyncio.run() joins the *default* executor's threads on shutdown, so a
    # timed-out call would still block the caller until the remote call ends.
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    call = functools.partial(client.text_to_image, prompt=prompt, seed=seed, model=model)
    try:
        result = await asyncio.wait_for(loop.run_in_executor(executor, call), timeout=timeout)
    except asyncio.TimeoutError:
        print(f"Task timed out for model: {model}")
        return None
    except Exception as e:
        print(e)
        print(f"Inference failed for model: {model}")
        return None
    finally:
        # Don't wait for a still-running (timed-out) call; its result is discarded.
        executor.shutdown(wait=False)

    if result is None:
        return None

    # Serialize writes under the shared module lock.
    with lock:
        # Write through the open handle instead of reopening the file by name,
        # which fails on Windows while the NamedTemporaryFile is still open.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_image:
            result.save(temp_image, format='PNG')
    return temp_image.name  # Path to the saved image

# Flask route for the API endpoint
@myapp.route('/generate_image', methods=['POST'])
def generate_image():
    """Generate an image from a JSON request and return it as ``image/png``.

    Expected JSON object fields:
        prompt (required): text prompt; an empty/missing prompt yields 400.
        seed: forwarded to ``infer``; defaults to 1.
        model: model id; defaults to ``'prompthero/openjourney-v4'`` and must
            be accepted by ``get_model_from_name``, otherwise 400.

    Returns:
        The PNG image on success; otherwise a JSON ``{"error": ...}`` body
        with status 400 (invalid request) or 500 (generation failed).
    """
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising, so the client gets a JSON 400 rather than an unhandled error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    # Extract required fields from the request
    prompt = data.get('prompt', '')
    seed = data.get('seed', 1)
    model_name = data.get('model', 'prompthero/openjourney-v4')  # Default model

    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400

    # Validate the requested model against the allowed list
    model = get_model_from_name(model_name)
    if not model:
        return jsonify({"error": f"Model '{model_name}' not found in available models"}), 400

    try:
        # Create a generic InferenceClient for the model
        client = InferenceClient(token=HF_TOKEN)

        # Call the async inference function
        result_path = asyncio.run(infer(client, prompt, seed, model=model))
        if not result_path:
            return jsonify({"error": "Failed to generate image"}), 500

        # infer() leaves the PNG on disk (delete=False); load it into memory
        # and delete it so generated images do not accumulate.
        try:
            with open(result_path, "rb") as f:
                image_bytes = BytesIO(f.read())
        finally:
            try:
                os.remove(result_path)
            except OSError as cleanup_error:
                print(f"Could not delete temp file {result_path}: {cleanup_error}")
        return send_file(image_bytes, mimetype='image/png')
    except Exception as e:
        print(f"Error in generate_image: {str(e)}")  # Log the error
        return jsonify({"error": str(e)}), 500

# Running this module directly starts Flask's development server on all
# interfaces, port 7860 (intended for local testing).
if __name__ == "__main__":
    myapp.run(port=7860, host='0.0.0.0')