#!/usr/bin/env python3
"""Small Flask HTTP API that serves images from a Stable Diffusion pipeline.

Endpoints:
    GET  /          service banner (name, model state, device).
    GET  /health    model / device / busy status.
    POST /generate  JSON body with ``prompt`` and optional ``steps`` and
                    ``guidance_scale``; responds with a 512x512 PNG.

Only one generation runs at a time; concurrent requests are answered with
503 instead of being queued.
"""
import io
import math
import threading
from datetime import datetime
from pathlib import Path

import torch
from diffusers import StableDiffusionPipeline
from flask import Flask, request, jsonify, send_file

app = Flask(__name__)

OUTPUT_DIR = Path("/app/generated_images")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Request parameter defaults and bounds for /generate.
DEFAULT_STEPS = 20
MIN_STEPS = 1
MAX_STEPS = 30
DEFAULT_GUIDANCE_SCALE = 7.5

pipeline = None
model_loaded = False
# Held for the duration of one generation; doubles as the "busy" flag.
lock = threading.Lock()

# NOTE(review): several log strings below look mojibake-garbled (e.g. the
# accented French words). They are runtime output, so they are reproduced
# unchanged here -- confirm the file's encoding before correcting them.


def load_model():
    """Load the Stable Diffusion pipeline into the module-level ``pipeline``.

    Uses float32 on CPU and float16 otherwise, disables the safety checker
    and enables attention slicing.

    Returns:
        True when the pipeline was loaded, False when loading raised (the
        error is printed, not re-raised, so the caller decides what to do).
    """
    global pipeline, model_loaded
    try:
        print(f"đŸ“± Appareil: {DEVICE}")
        dtype = torch.float32 if DEVICE == "cpu" else torch.float16
        pipeline = StableDiffusionPipeline.from_pretrained(
            "K2MAR/mon-modele-sd",
            torch_dtype=dtype,
            safety_checker=None,
        ).to(DEVICE)
        pipeline.enable_attention_slicing()
        model_loaded = True
        print("✅ ModĂšle prĂȘt!")
        return True
    except Exception as e:
        print(f"❌ Erreur: {e}")
        return False


@app.route('/health', methods=['GET'])
def health():
    """Report whether the model is loaded, the device, and the busy flag."""
    return jsonify({
        "status": "ok" if model_loaded else "loading",
        "device": DEVICE,
        "model_loaded": model_loaded,
        "busy": lock.locked(),
    })


def _parse_generation_params(data):
    """Validate the decoded JSON payload of a /generate request.

    Args:
        data: whatever ``request.get_json(silent=True)`` returned (may be
            ``None`` or a non-dict JSON value).

    Returns:
        ``(params, None)`` on success, where ``params`` holds ``prompt``,
        ``steps`` (clamped to ``MIN_STEPS..MAX_STEPS``) and
        ``guidance_scale``; otherwise ``(None, message)`` with a
        client-facing error message.
    """
    # A JSON array or scalar is valid JSON but not a valid request body.
    if not isinstance(data, dict) or "prompt" not in data:
        return None, "Missing 'prompt'"

    try:
        steps = int(data.get("steps", DEFAULT_STEPS))
    except (TypeError, ValueError, OverflowError):
        return None, "'steps' must be an integer"

    try:
        guidance_scale = float(
            data.get("guidance_scale", DEFAULT_GUIDANCE_SCALE)
        )
    except (TypeError, ValueError, OverflowError):
        return None, "'guidance_scale' must be a number"
    # Python's JSON decoder accepts NaN/Infinity literals; reject them here.
    if not math.isfinite(guidance_scale):
        return None, "'guidance_scale' must be a finite number"

    prompt = data["prompt"]
    if not prompt:
        return None, "Prompt cannot be empty"

    # The upper bound was already enforced silently; enforce a lower bound
    # too so zero/negative step counts never reach the pipeline.
    steps = max(MIN_STEPS, min(steps, MAX_STEPS))
    return {
        "prompt": prompt,
        "steps": steps,
        "guidance_scale": guidance_scale,
    }, None


@app.route('/generate', methods=['POST'])
def generate():
    """Generate one 512x512 image from the posted prompt and return it as PNG.

    Responses:
        200 image/png  generated image.
        400 JSON       missing/empty prompt, malformed JSON, or invalid
                       ``steps`` / ``guidance_scale``.
        503 JSON       model not loaded, or another generation is running.
        500 JSON       the pipeline (or encoding the image) raised.
    """
    if not model_loaded:
        return jsonify({"error": "Model not loaded"}), 503

    # Atomic test-and-acquire: checking lock.locked() and acquiring later
    # lets two requests both pass the check and then queue on the lock.
    if not lock.acquire(blocking=False):
        return jsonify({"error": "Server busy, try again later"}), 503

    try:
        # silent=True: malformed JSON or a wrong Content-Type yields None
        # (reported as 400 below) instead of an HTTP exception that the
        # generic handler would turn into a 500.
        params, error = _parse_generation_params(
            request.get_json(silent=True)
        )
        if error is not None:
            return jsonify({"error": error}), 400

        prompt = params["prompt"]
        steps = params["steps"]
        print(f"\n🎹 GĂ©nĂ©ration: {prompt} ({steps} steps)")

        with torch.no_grad():
            image = pipeline(
                prompt,
                num_inference_steps=steps,
                guidance_scale=params["guidance_scale"],
                height=512,
                width=512,
            ).images[0]

        img_io = io.BytesIO()
        image.save(img_io, 'PNG')
        img_io.seek(0)

        print("✅ Image gĂ©nĂ©rĂ©e!\n")
        return send_file(img_io, mimetype='image/png')
    except Exception as e:
        # Top-level request boundary: report the failure to the client
        # rather than letting Flask emit an HTML error page.
        print(f"❌ Erreur: {str(e)}\n")
        return jsonify({"error": str(e)}), 500
    finally:
        lock.release()


@app.route('/', methods=['GET'])
def home():
    """Return a short JSON banner describing the service."""
    return jsonify({
        "service": "SD API",
        "model_loaded": model_loaded,
        "device": DEVICE,
    })


if __name__ == '__main__':
    if not load_model():
        # `exit` is the site-module helper meant for interactive use and is
        # not guaranteed to exist; SystemExit always is.
        raise SystemExit(1)

    print("\n🚀 Serveur sur 0.0.0.0:7860\n")
    app.run(host='0.0.0.0', port=7860, debug=False, threaded=True)