Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| import torch | |
| import threading | |
| from pathlib import Path | |
| from flask import Flask, request, jsonify, send_file | |
| from diffusers import StableDiffusionPipeline | |
| from datetime import datetime | |
| import io | |
# Flask application object that serves the HTTP API.
app = Flask(__name__)

# Directory reserved for generated images; created eagerly at import time.
OUTPUT_DIR = Path("/app/generated_images")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Shared state: assigned by load_model() and read by the request handlers.
pipeline = None
model_loaded = False

# Held while the pipeline is running an inference.
lock = threading.Lock()
def load_model():
    """Load the Stable Diffusion pipeline into the module-level globals.

    Uses float32 weights on CPU and float16 otherwise, moves the pipeline
    to ``DEVICE``, enables attention slicing and finally sets
    ``model_loaded``.

    Returns:
        True when every step succeeded, False if any step raised (the
        error is printed, not re-raised).
    """
    global pipeline, model_loaded

    try:
        print(f"📱 Appareil: {DEVICE}")
        weights_dtype = torch.float16 if DEVICE != "cpu" else torch.float32
        loaded = StableDiffusionPipeline.from_pretrained(
            "K2MAR/mon-modele-sd",
            torch_dtype=weights_dtype,
            safety_checker=None,
        )
        pipeline = loaded.to(DEVICE)
        pipeline.enable_attention_slicing()
        model_loaded = True
        print("✅ Modèle prêt!")
        return True
    except Exception as err:
        print(f"❌ Erreur: {err}")
        return False
# NOTE(review): no route decorator is visible for this handler in this view;
# confirm how it is registered on `app`.
def health():
    """Return a JSON snapshot of the service state.

    The payload reports whether the model is loaded, the compute device,
    and whether the inference lock is currently held.
    """
    if model_loaded:
        status = "ok"
    else:
        status = "loading"

    payload = {
        "status": status,
        "device": DEVICE,
        "model_loaded": model_loaded,
        "busy": lock.locked(),
    }
    return jsonify(payload)
def _render_image(prompt, steps, guidance_scale):
    """Run one 512x512 inference while holding the global lock.

    Returns the first generated image, or None when another inference
    already holds the lock (the call never blocks waiting for it).
    """
    # Atomic test-and-acquire: a separate `lock.locked()` check followed by
    # a blocking `with lock:` would let concurrent requests queue up.
    if not lock.acquire(blocking=False):
        return None
    try:
        with torch.no_grad():
            return pipeline(
                prompt,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                height=512,
                width=512,
            ).images[0]
    finally:
        lock.release()


# NOTE(review): no route decorator is visible for this handler in this view;
# confirm how it is registered on `app`.
def generate():
    """Generate a PNG image from the JSON request body.

    JSON fields:
        prompt: required, must be non-empty.
        steps: optional number of inference steps, default 20, clamped
            to the range 1..30.
        guidance_scale: optional float, default 7.5.

    Returns:
        The PNG image on success. Otherwise a JSON ``{"error": ...}`` body
        with status 400 (invalid request), 503 (model not loaded, or an
        inference is already running) or 500 (generation failed).
    """
    if not model_loaded:
        return jsonify({"error": "Model not loaded"}), 503
    # Cheap early rejection only; the authoritative busy check is the
    # non-blocking acquire inside _render_image().
    if lock.locked():
        return jsonify({"error": "Server busy, try again later"}), 503

    # silent=True yields None for a missing or malformed JSON body instead
    # of raising, so a client mistake is reported as 400 rather than 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "prompt" not in data:
        return jsonify({"error": "Missing 'prompt'"}), 400

    prompt = data["prompt"]
    if not prompt:
        return jsonify({"error": "Prompt cannot be empty"}), 400

    try:
        # Clamp to at least one step; the upper cap of 30 bounds the work
        # a single request can trigger.
        steps = max(1, min(int(data.get("steps", 20)), 30))
        guidance_scale = float(data.get("guidance_scale", 7.5))
    except (TypeError, ValueError, OverflowError):
        return jsonify(
            {"error": "'steps' and 'guidance_scale' must be numeric"}
        ), 400

    try:
        print(f"\n🎨 Génération: {prompt} ({steps} steps)")
        image = _render_image(prompt, steps, guidance_scale)
        if image is None:
            return jsonify({"error": "Server busy, try again later"}), 503

        # Encode outside the lock so the next inference can start while
        # this response is being serialized.
        img_io = io.BytesIO()
        image.save(img_io, 'PNG')
        img_io.seek(0)
        print("✅ Image générée!\n")
        return send_file(img_io, mimetype='image/png')
    except Exception as e:
        # Handler boundary: report the failure to the client instead of
        # letting the exception escape.
        print(f"❌ Erreur: {str(e)}\n")
        return jsonify({"error": str(e)}), 500
# NOTE(review): no route decorator is visible for this handler in this view;
# confirm how it is registered on `app`.
def home():
    """Return a short JSON description of the service and its state."""
    info = {
        "service": "SD API",
        "model_loaded": model_loaded,
        "device": DEVICE,
    }
    return jsonify(info)
if __name__ == '__main__':
    # Refuse to serve when the model could not be loaded. SystemExit is
    # raised directly because the bare `exit` helper is injected by the
    # `site` module and is not guaranteed to exist (e.g. `python -S`).
    if not load_model():
        raise SystemExit(1)
    print("\n🚀 Serveur sur 0.0.0.0:7860\n")
    app.run(host='0.0.0.0', port=7860, debug=False, threaded=True)