K2MAR committed on
Commit
eaed1b7
·
1 Parent(s): 46f07b1

Deploy SD API

Browse files
Files changed (2) hide show
  1. Dockerfile +28 -0
  2. api_server.py +128 -0
Dockerfile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Container image for the Stable Diffusion Flask API (CPU-only PyTorch wheels).
FROM python:3.10-slim

# Flush Python output immediately so log lines reach the container log stream.
ENV PYTHONUNBUFFERED=1

WORKDIR /app

# git is kept from the original image (presumably for VCS-based pip installs —
# TODO confirm it is still needed). Recommended packages are skipped and the
# apt lists removed to keep the layer small.
RUN apt-get update \
    && apt-get install -y --no-install-recommends git \
    && rm -rf /var/lib/apt/lists/*

# Heavy dependency layers come before the application source so that editing
# api_server.py does not invalidate the cached torch/diffusers installs.
RUN pip install --no-cache-dir \
    torch torchvision torchaudio \
    --index-url https://download.pytorch.org/whl/cpu

RUN pip install --no-cache-dir \
    diffusers \
    transformers \
    safetensors \
    peft \
    flask \
    pillow \
    accelerate

# Run as an unprivileged user (UID 1000) that owns the directories the app
# writes to at runtime: generated images under /app and the model cache under
# the user's home directory.
RUN useradd --create-home --uid 1000 user \
    && mkdir -p /app/model /app/generated_images \
    && chown -R user:user /app
USER user
ENV HOME=/home/user

COPY --chown=user:user api_server.py /app/

EXPOSE 7860

CMD ["python3", "-u", "/app/api_server.py"]
api_server.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Flask API for generating images.
4
+ Endpoints: GET /, GET /health, POST /generate
5
+ """
6
+
7
+ import torch
8
+ from pathlib import Path
9
+ from flask import Flask, request, jsonify, send_file
10
+ from diffusers import StableDiffusionPipeline
11
+ from datetime import datetime
12
+ import io
13
+
14
# Flask application object that the route decorators below attach to.
app = Flask(__name__)

# Directory where every generated image is also written to disk; created
# (including parents) at import time if it does not exist yet.
OUTPUT_DIR = Path("/app/generated_images")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Module-level state, filled in by load_model().
pipeline = None
model_loaded = False
25
+
26
def load_model():
    """Load the Stable Diffusion pipeline into the module-level globals.

    Loads ``K2MAR/mon-modele-sd`` with the safety checker disabled, moves it
    to ``DEVICE`` and enables attention slicing. float32 is used when
    ``DEVICE`` is ``"cpu"``, float16 otherwise.

    Returns:
        bool: True once the pipeline is ready and ``model_loaded`` has been
        set; False if any exception was raised (the error is printed).
    """
    global pipeline, model_loaded

    banner = "=" * 70
    print("\n" + banner)
    print("🤖 Chargement du modèle fusionné...")
    print(banner + "\n")

    try:
        print(f"📱 Appareil: {DEVICE}")

        if DEVICE == "cpu":
            dtype = torch.float32
        else:
            dtype = torch.float16

        loaded = StableDiffusionPipeline.from_pretrained(
            "K2MAR/mon-modele-sd",
            torch_dtype=dtype,
            safety_checker=None,
        )
        pipeline = loaded.to(DEVICE)

        pipeline.enable_attention_slicing()

        # Only flag the model as usable after every setup step succeeded.
        model_loaded = True
        print(banner)
        print("✅ Modèle prêt!")
        print(banner + "\n")
        return True
    except Exception as e:
        print(f"❌ Erreur: {e}")
        return False
55
+
56
@app.route('/health', methods=['GET'])
def health():
    """Return a JSON readiness report.

    ``status`` is ``"ok"`` once the model has been loaded and ``"loading"``
    before that; the device in use and the raw ``model_loaded`` flag are
    included as well.
    """
    if model_loaded:
        status = "ok"
    else:
        status = "loading"

    payload = {
        "status": status,
        "device": DEVICE,
        "model_loaded": model_loaded,
    }
    return jsonify(payload)
63
+
64
@app.route('/generate', methods=['POST'])
def generate():
    """Generate a 512x512 PNG image from a JSON request and return it.

    Expected JSON object:
        prompt (required): non-empty text prompt.
        steps (optional): number of inference steps, 1-50, default 30.
        guidance_scale (optional): guidance scale, default 7.5.

    Returns:
        The PNG image (``image/png``) on success. Otherwise a JSON body
        ``{"error": ...}`` with status 503 (model not loaded), 400 (invalid
        request) or 500 (failure while generating or saving the image).
        A copy of every generated image is also written to ``OUTPUT_DIR``.
    """
    if not model_loaded:
        return jsonify({"error": "Model not loaded"}), 503

    # silent=True returns None for a missing, mistyped or malformed JSON body
    # instead of raising, so such requests get the 400 below rather than
    # being reported as a server error.
    data = request.get_json(silent=True)

    # A JSON array/string/number is valid JSON but not a usable request.
    if not isinstance(data, dict) or "prompt" not in data:
        return jsonify({"error": "Missing 'prompt' in request"}), 400

    prompt = data.get("prompt", "")
    if not prompt:
        return jsonify({"error": "Prompt cannot be empty"}), 400

    # Bad numeric parameters are a client error, not a server error.
    try:
        steps = int(data.get("steps", 30))
        guidance_scale = float(data.get("guidance_scale", 7.5))
    except (TypeError, ValueError, OverflowError):
        return jsonify({
            "error": "'steps' must be an integer and "
                     "'guidance_scale' must be a number"
        }), 400

    if steps < 1 or steps > 50:
        return jsonify({"error": "Steps must be 1-50"}), 400

    # NOTE(review): the server is started with threaded=True, so several
    # requests may call the shared pipeline at once — confirm this is safe.
    try:
        print(f"\n🎨 Génération: {prompt}")

        with torch.no_grad():
            image = pipeline(
                prompt,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                height=512,
                width=512,
            ).images[0]

        # Encode the PNG once and reuse the bytes for both the on-disk copy
        # and the HTTP response.
        img_io = io.BytesIO()
        image.save(img_io, 'PNG')

        # Microsecond resolution keeps requests finishing within the same
        # second from overwriting each other's file.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filepath = OUTPUT_DIR / f"generated_{timestamp}.png"
        filepath.write_bytes(img_io.getvalue())

        img_io.seek(0)
        return send_file(img_io, mimetype='image/png')

    except Exception as e:
        print(f"❌ Erreur: {str(e)}\n")
        return jsonify({"error": str(e)}), 500
109
+
110
@app.route('/', methods=['GET'])
def home():
    """Return a JSON description of the service and its endpoints."""
    endpoints = {
        "health": "GET /health",
        "generate": "POST /generate",
    }
    info = {
        "service": "LoRA Solar Panel Generator API",
        "version": "1.0",
        "device": DEVICE,
        "model_loaded": model_loaded,
        "endpoints": endpoints,
    }
    return jsonify(info)
122
+
123
if __name__ == '__main__':
    # Abort start-up with a non-zero status when the model cannot be loaded.
    # SystemExit is raised directly because the bare exit() helper is injected
    # by the site module and is not guaranteed to exist in every interpreter
    # configuration.
    if not load_model():
        raise SystemExit(1)

    print("\n🚀 Serveur démarrage sur 0.0.0.0:7860\n")
    # Listen on all interfaces on port 7860, the port exposed by the image.
    app.run(host='0.0.0.0', port=7860, debug=False, threaded=True)