K2MAR committed on
Commit
dce5085
·
1 Parent(s): eaed1b7

Fix threading lock

Browse files
Files changed (1) hide show
  1. api_server.py +30 -61
api_server.py CHANGED
@@ -1,10 +1,6 @@
1
  #!/usr/bin/env python3
2
- """
3
- API Flask pour générer des images
4
- Endpoint unique: /generate
5
- """
6
-
7
  import torch
 
8
  from pathlib import Path
9
  from flask import Flask, request, jsonify, send_file
10
  from diffusers import StableDiffusionPipeline
@@ -13,41 +9,26 @@ import io
13
 
14
  app = Flask(__name__)
15
 
16
- # Configuration
17
  OUTPUT_DIR = Path("/app/generated_images")
18
  OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
19
-
20
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
-
22
- # Variable globale
23
  pipeline = None
24
  model_loaded = False
 
25
 
26
  def load_model():
27
- """Charge le modèle au démarrage"""
28
  global pipeline, model_loaded
29
-
30
- print("\n" + "="*70)
31
- print("🤖 Chargement du modèle fusionné...")
32
- print("="*70 + "\n")
33
-
34
  try:
35
  print(f"📱 Appareil: {DEVICE}")
36
-
37
  dtype = torch.float32 if DEVICE == "cpu" else torch.float16
38
-
39
  pipeline = StableDiffusionPipeline.from_pretrained(
40
  "K2MAR/mon-modele-sd",
41
  torch_dtype=dtype,
42
  safety_checker=None
43
  ).to(DEVICE)
44
-
45
  pipeline.enable_attention_slicing()
46
-
47
  model_loaded = True
48
- print("="*70)
49
  print("✅ Modèle prêt!")
50
- print("="*70 + "\n")
51
  return True
52
  except Exception as e:
53
  print(f"❌ Erreur: {e}")
@@ -58,71 +39,59 @@ def health():
58
  return jsonify({
59
  "status": "ok" if model_loaded else "loading",
60
  "device": DEVICE,
61
- "model_loaded": model_loaded
 
62
  })
63
 
64
  @app.route('/generate', methods=['POST'])
65
  def generate():
66
  if not model_loaded:
67
  return jsonify({"error": "Model not loaded"}), 503
68
-
 
 
 
69
  try:
70
  data = request.get_json()
71
-
72
  if not data or "prompt" not in data:
73
- return jsonify({"error": "Missing 'prompt' in request"}), 400
74
-
75
  prompt = data.get("prompt", "")
76
- steps = int(data.get("steps", 30))
77
  guidance_scale = float(data.get("guidance_scale", 7.5))
78
-
79
  if not prompt:
80
  return jsonify({"error": "Prompt cannot be empty"}), 400
81
-
82
- if steps < 1 or steps > 50:
83
- return jsonify({"error": "Steps must be 1-50"}), 400
84
-
85
- print(f"\n🎨 Génération: {prompt}")
86
-
87
- with torch.no_grad():
88
- image = pipeline(
89
- prompt,
90
- num_inference_steps=steps,
91
- guidance_scale=guidance_scale,
92
- height=512,
93
- width=512
94
- ).images[0]
95
-
96
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
97
- filepath = OUTPUT_DIR / f"generated_{timestamp}.png"
98
- image.save(filepath)
99
-
100
  img_io = io.BytesIO()
101
  image.save(img_io, 'PNG')
102
  img_io.seek(0)
103
-
 
104
  return send_file(img_io, mimetype='image/png')
105
-
106
  except Exception as e:
107
  print(f"❌ Erreur: {str(e)}\n")
108
  return jsonify({"error": str(e)}), 500
109
 
110
  @app.route('/', methods=['GET'])
111
  def home():
112
- return jsonify({
113
- "service": "LoRA Solar Panel Generator API",
114
- "version": "1.0",
115
- "device": DEVICE,
116
- "model_loaded": model_loaded,
117
- "endpoints": {
118
- "health": "GET /health",
119
- "generate": "POST /generate"
120
- }
121
- })
122
 
123
  if __name__ == '__main__':
124
  if not load_model():
125
  exit(1)
126
-
127
- print("\n🚀 Serveur démarrage sur 0.0.0.0:7860\n")
128
  app.run(host='0.0.0.0', port=7860, debug=False, threaded=True)
 
1
  #!/usr/bin/env python3
 
 
 
 
 
2
  import torch
3
+ import threading
4
  from pathlib import Path
5
  from flask import Flask, request, jsonify, send_file
6
  from diffusers import StableDiffusionPipeline
 
9
 
10
  app = Flask(__name__)
11
 
 
12
  OUTPUT_DIR = Path("/app/generated_images")
13
  OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
 
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
15
  pipeline = None
16
  model_loaded = False
17
+ lock = threading.Lock()
18
 
19
  def load_model():
 
20
  global pipeline, model_loaded
 
 
 
 
 
21
  try:
22
  print(f"📱 Appareil: {DEVICE}")
 
23
  dtype = torch.float32 if DEVICE == "cpu" else torch.float16
 
24
  pipeline = StableDiffusionPipeline.from_pretrained(
25
  "K2MAR/mon-modele-sd",
26
  torch_dtype=dtype,
27
  safety_checker=None
28
  ).to(DEVICE)
 
29
  pipeline.enable_attention_slicing()
 
30
  model_loaded = True
 
31
  print("✅ Modèle prêt!")
 
32
  return True
33
  except Exception as e:
34
  print(f"❌ Erreur: {e}")
 
39
  return jsonify({
40
  "status": "ok" if model_loaded else "loading",
41
  "device": DEVICE,
42
+ "model_loaded": model_loaded,
43
+ "busy": lock.locked()
44
  })
45
 
46
  @app.route('/generate', methods=['POST'])
47
  def generate():
48
  if not model_loaded:
49
  return jsonify({"error": "Model not loaded"}), 503
50
+
51
+ if lock.locked():
52
+ return jsonify({"error": "Server busy, try again later"}), 503
53
+
54
  try:
55
  data = request.get_json()
 
56
  if not data or "prompt" not in data:
57
+ return jsonify({"error": "Missing 'prompt'"}), 400
58
+
59
  prompt = data.get("prompt", "")
60
+ steps = min(int(data.get("steps", 20)), 30)
61
  guidance_scale = float(data.get("guidance_scale", 7.5))
62
+
63
  if not prompt:
64
  return jsonify({"error": "Prompt cannot be empty"}), 400
65
+
66
+ print(f"\n🎨 Génération: {prompt} ({steps} steps)")
67
+
68
+ with lock:
69
+ with torch.no_grad():
70
+ image = pipeline(
71
+ prompt,
72
+ num_inference_steps=steps,
73
+ guidance_scale=guidance_scale,
74
+ height=512,
75
+ width=512
76
+ ).images[0]
77
+
 
 
 
 
 
 
78
  img_io = io.BytesIO()
79
  image.save(img_io, 'PNG')
80
  img_io.seek(0)
81
+
82
+ print(f"✅ Image générée!\n")
83
  return send_file(img_io, mimetype='image/png')
84
+
85
  except Exception as e:
86
  print(f"❌ Erreur: {str(e)}\n")
87
  return jsonify({"error": str(e)}), 500
88
 
89
  @app.route('/', methods=['GET'])
90
  def home():
91
+ return jsonify({"service": "SD API", "model_loaded": model_loaded, "device": DEVICE})
 
 
 
 
 
 
 
 
 
92
 
93
  if __name__ == '__main__':
94
  if not load_model():
95
  exit(1)
96
+ print("\n🚀 Serveur sur 0.0.0.0:7860\n")
 
97
  app.run(host='0.0.0.0', port=7860, debug=False, threaded=True)