File size: 2,868 Bytes
eaed1b7
 
dce5085
eaed1b7
 
 
 
 
 
 
 
 
 
 
 
 
dce5085
eaed1b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dce5085
 
eaed1b7
 
 
 
 
 
dce5085
 
 
 
eaed1b7
 
 
dce5085
 
eaed1b7
dce5085
eaed1b7
dce5085
eaed1b7
 
dce5085
 
 
 
 
 
 
 
 
 
 
 
 
eaed1b7
 
 
dce5085
 
eaed1b7
dce5085
eaed1b7
 
 
 
 
 
dce5085
eaed1b7
 
 
 
dce5085
eaed1b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python3
import io
import threading
import traceback
from datetime import datetime
from pathlib import Path

import torch
from diffusers import StableDiffusionPipeline
from flask import Flask, request, jsonify, send_file

app = Flask(__name__)

# Output directory, created eagerly at import time.
# NOTE(review): none of the request handlers in this module write to it.
OUTPUT_DIR = Path("/app/generated_images")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Shared inference state; load_model() fills these in.
pipeline = None
model_loaded = False
# Guards the pipeline so only one generation runs at a time.
lock = threading.Lock()

def load_model(model_id="K2MAR/mon-modele-sd"):
    """Load the Stable Diffusion pipeline onto DEVICE and mark it ready.

    On success, stores the configured pipeline in the module-level
    ``pipeline`` and sets ``model_loaded`` to True. Any exception is
    reported (message plus full traceback) and converted into a False
    return value so the caller decides how to react.

    Args:
        model_id: Repo id or local path handed to
            ``StableDiffusionPipeline.from_pretrained``. Defaults to the
            model this service has always served.

    Returns:
        True if the model is loaded and ready, False otherwise.
    """
    global pipeline, model_loaded
    try:
        print(f"📱 Appareil: {DEVICE}")
        # Half precision on GPU to save memory; float16 is poorly supported
        # by many CPU kernels, so the CPU path stays in float32.
        dtype = torch.float32 if DEVICE == "cpu" else torch.float16
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            safety_checker=None,  # safety checker deliberately disabled
        ).to(DEVICE)
        # Lowers peak attention memory at some cost in speed.
        pipe.enable_attention_slicing()
        # Publish the global only once the pipeline is fully configured, so a
        # failure above never leaves a half-initialised `pipeline` behind.
        pipeline = pipe
        model_loaded = True
        print("✅ Modèle prêt!")
        return True
    except Exception as e:
        # Top-level boundary: report everything, let the caller decide.
        print(f"❌ Erreur: {e}")
        traceback.print_exc()
        return False

@app.route('/health', methods=['GET'])
def health():
    """Report readiness, the compute device, and whether the lock is held."""
    if model_loaded:
        status = "ok"
    else:
        status = "loading"
    return jsonify(
        status=status,
        device=DEVICE,
        model_loaded=model_loaded,
        busy=lock.locked(),
    )

def _parse_generation_request(data):
    """Validate a /generate JSON body and extract generation parameters.

    Args:
        data: The decoded JSON body, or None when the body was absent or
            could not be parsed as JSON.

    Returns:
        A ``(prompt, steps, guidance_scale)`` tuple, with ``steps`` clamped
        to the range [1, 30].

    Raises:
        ValueError: carrying a client-facing message when the body is unusable.
    """
    # A JSON array or scalar body has no "prompt" key to read.
    if not isinstance(data, dict) or "prompt" not in data:
        raise ValueError("Missing 'prompt'")

    prompt = data["prompt"]
    if not prompt or (isinstance(prompt, str) and not prompt.strip()):
        raise ValueError("Prompt cannot be empty")
    if not isinstance(prompt, str):
        raise ValueError("'prompt' must be a string")

    try:
        steps = int(data.get("steps", 20))
        guidance_scale = float(data.get("guidance_scale", 7.5))
    except (TypeError, ValueError, OverflowError):
        raise ValueError("'steps' and 'guidance_scale' must be numbers") from None

    # Cap the step count to bound latency; the pipeline needs at least one.
    steps = max(1, min(steps, 30))
    return prompt, steps, guidance_scale


@app.route('/generate', methods=['POST'])
def generate():
    """Generate one 512x512 PNG image from a JSON prompt.

    Expects a JSON object with a required non-empty string ``prompt`` and
    optional ``steps`` (default 20, clamped to [1, 30]) and
    ``guidance_scale`` (default 7.5).

    Returns:
        The PNG image on success. Otherwise a JSON ``{"error": ...}`` with
        status 400 (unusable request body), 503 (model not loaded, or a
        generation already in progress) or 500 (generation failed).
    """
    if not model_loaded:
        return jsonify({"error": "Model not loaded"}), 503

    # silent=True returns None for a missing/malformed body or a wrong
    # Content-Type instead of raising, so those become a 400 below.
    data = request.get_json(silent=True)
    try:
        prompt, steps, guidance_scale = _parse_generation_request(data)
    except ValueError as e:
        return jsonify({"error": str(e)}), 400

    # Take the lock without waiting. Checking lock.locked() and then
    # acquiring separately is racy: a second request could slip past the
    # check and queue behind the first instead of receiving a 503.
    if not lock.acquire(blocking=False):
        return jsonify({"error": "Server busy, try again later"}), 503

    try:
        print(f"\n🎨 Génération: {prompt} ({steps} steps)")
        with torch.no_grad():
            image = pipeline(
                prompt,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                height=512,
                width=512
            ).images[0]

        img_io = io.BytesIO()
        image.save(img_io, 'PNG')
        img_io.seek(0)

        print("✅ Image générée!\n")
        return send_file(img_io, mimetype='image/png')

    except Exception as e:
        print(f"❌ Erreur: {str(e)}\n")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        # Always release, whichever path leaves the try above.
        lock.release()

@app.route('/', methods=['GET'])
def home():
    """Describe the service: its name, model readiness, and compute device."""
    info = {"service": "SD API"}
    info["model_loaded"] = model_loaded
    info["device"] = DEVICE
    return jsonify(info)

if __name__ == '__main__':
    # Load the model before serving so no request is handled without it;
    # a failed load aborts the process with a non-zero exit status.
    if not load_model():
        # Not the bare exit() builtin: it is injected by the `site` module
        # and is absent under `python -S` or in some embedded interpreters.
        raise SystemExit(1)
    print("\n🚀 Serveur sur 0.0.0.0:7860\n")
    # threaded=True lets /health and / answer while a generation is running.
    app.run(host='0.0.0.0', port=7860, debug=False, threaded=True)