arcsu1 commited on
Commit
4e8e262
·
1 Parent(s): 961e88e

Fix model path to point to .h5 file

Browse files
Files changed (3) hide show
  1. Dockerfile +2 -2
  2. app.py +69 -49
  3. requirements.txt +2 -3
Dockerfile CHANGED
@@ -16,7 +16,7 @@ RUN pip install --upgrade pip
16
  COPY requirements.txt .
17
 
18
  # Install packages separately to handle large downloads better
19
- RUN pip install --default-timeout=1000 --no-cache-dir fastapi[all] uvicorn[standard] pydantic
20
 
21
  # Install TensorFlow CPU (smaller and faster to download than GPU version)
22
  RUN pip install --default-timeout=1000 --no-cache-dir tensorflow-cpu==2.15.0
@@ -34,4 +34,4 @@ COPY models/ ./models/
34
  EXPOSE 8002
35
 
36
  # Run the application
37
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8002"]
 
16
  COPY requirements.txt .
17
 
18
  # Install packages separately to handle large downloads better
19
+ RUN pip install --default-timeout=1000 --no-cache-dir flask flask-cors
20
 
21
  # Install TensorFlow CPU (smaller and faster to download than GPU version)
22
  RUN pip install --default-timeout=1000 --no-cache-dir tensorflow-cpu==2.15.0
 
34
  EXPOSE 8002
35
 
36
  # Run the application
37
+ CMD ["python", "app.py"]
app.py CHANGED
@@ -1,78 +1,70 @@
1
- from fastapi import FastAPI
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.responses import StreamingResponse
4
- from pydantic import BaseModel
5
  import numpy as np
6
  from keras.models import load_model
7
  from PIL import Image
8
  import io
9
- from typing import Optional
10
 
11
- app = FastAPI(title="Face Generator API", version="1.0.0")
12
 
13
- # Add CORS middleware
14
- app.add_middleware(
15
- CORSMiddleware,
16
- allow_origins=["*"],
17
- allow_credentials=True,
18
- allow_methods=["*"],
19
- allow_headers=["*"],
20
- )
21
 
22
  # Global variables for model
23
- MODEL_PATH = "./models/face-gen-gan"
24
  model = None
25
  latent_dim = None
26
 
27
- # Request models
28
- class FaceGenRequest(BaseModel):
29
- n_samples: Optional[int] = 1
30
- seed: Optional[int] = None
31
-
32
- @app.on_event("startup")
33
- async def load_gan_model():
34
- """Load the GAN model on startup"""
35
  global model, latent_dim
36
- print(f"Loading face generation GAN model from {MODEL_PATH}...")
37
-
38
- model = load_model(MODEL_PATH)
39
- latent_dim = model.input_shape[1]
40
-
41
- print(f"Model loaded successfully! Latent dimension: {latent_dim}")
42
 
43
- @app.get("/")
 
 
 
44
  def root():
45
- return {
46
  "message": "Face Generator API",
47
  "status": "running",
48
  "model": "face-gen-gan",
49
  "latent_dim": latent_dim
50
- }
51
 
52
- @app.get("/health")
53
  def health():
54
- return {
55
  "status": "healthy",
56
  "model_loaded": model is not None,
57
  "latent_dim": latent_dim
58
- }
59
 
60
- @app.post("/generate")
61
- async def generate_faces(request: FaceGenRequest):
62
  """
63
  Generate face images using the GAN model
64
  Returns a PNG image (single face or grid of faces)
65
  """
66
  if model is None:
67
- return {"error": "Model not loaded"}
68
 
69
  try:
 
 
 
 
 
70
  # Validate n_samples
71
- n_samples = max(1, min(request.n_samples, 16)) # Limit to 1-16
72
 
73
  # Set seed if provided
74
- if request.seed is not None:
75
- np.random.seed(request.seed)
76
 
77
  # Generate random latent points
78
  latent_points = np.random.randn(n_samples, latent_dim)
@@ -108,19 +100,47 @@ async def generate_faces(request: FaceGenRequest):
108
  img.save(buf, format='PNG')
109
  buf.seek(0)
110
 
111
- return StreamingResponse(buf, media_type="image/png")
112
 
113
  except Exception as e:
114
- return {"error": str(e)}
115
 
116
- @app.get("/generate-single")
117
- async def generate_single_face(seed: Optional[int] = None):
118
  """
119
  Quick endpoint to generate a single face
120
  """
121
- request = FaceGenRequest(n_samples=1, seed=seed)
122
- return await generate_faces(request)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  if __name__ == "__main__":
125
- import uvicorn
126
- uvicorn.run(app, host="0.0.0.0", port=8002)
 
1
+ from flask import Flask, jsonify, request, send_file
2
+ from flask_cors import CORS
 
 
3
  import numpy as np
4
  from keras.models import load_model
5
  from PIL import Image
6
  import io
 
7
 
8
+ app = Flask(__name__)
9
 
10
+ # Enable CORS for all routes
11
+ CORS(app)
 
 
 
 
 
 
12
 
13
  # Global variables for model
14
+ MODEL_PATH = "./models/face-gen-gan/generator_model_100.h5"
15
  model = None
16
  latent_dim = None
17
 
18
+ def load_gan_model():
19
+ """Load the GAN model"""
 
 
 
 
 
 
20
  global model, latent_dim
21
+ if model is None:
22
+ print(f"Loading face generation GAN model from {MODEL_PATH}...")
23
+ model = load_model(MODEL_PATH)
24
+ latent_dim = model.input_shape[1]
25
+ print(f"Model loaded successfully! Latent dimension: {latent_dim}")
 
26
 
27
+ # Load model on startup
28
+ load_gan_model()
29
+
30
+ @app.route("/")
31
  def root():
32
+ return jsonify({
33
  "message": "Face Generator API",
34
  "status": "running",
35
  "model": "face-gen-gan",
36
  "latent_dim": latent_dim
37
+ })
38
 
39
+ @app.route("/health")
40
  def health():
41
+ return jsonify({
42
  "status": "healthy",
43
  "model_loaded": model is not None,
44
  "latent_dim": latent_dim
45
+ })
46
 
47
+ @app.route("/generate", methods=["POST"])
48
+ def generate_faces():
49
  """
50
  Generate face images using the GAN model
51
  Returns a PNG image (single face or grid of faces)
52
  """
53
  if model is None:
54
+ return jsonify({"error": "Model not loaded"}), 500
55
 
56
  try:
57
+ # Get request data
58
+ data = request.get_json() or {}
59
+ n_samples = data.get("n_samples", 1)
60
+ seed = data.get("seed", None)
61
+
62
  # Validate n_samples
63
+ n_samples = max(1, min(int(n_samples), 16)) # Limit to 1-16
64
 
65
  # Set seed if provided
66
+ if seed is not None:
67
+ np.random.seed(int(seed))
68
 
69
  # Generate random latent points
70
  latent_points = np.random.randn(n_samples, latent_dim)
 
100
  img.save(buf, format='PNG')
101
  buf.seek(0)
102
 
103
+ return send_file(buf, mimetype='image/png')
104
 
105
  except Exception as e:
106
+ return jsonify({"error": str(e)}), 500
107
 
108
+ @app.route("/generate-single")
109
+ def generate_single_face():
110
  """
111
  Quick endpoint to generate a single face
112
  """
113
+ seed = request.args.get('seed', None)
114
+
115
+ if model is None:
116
+ return jsonify({"error": "Model not loaded"}), 500
117
+
118
+ try:
119
+ # Set seed if provided
120
+ if seed is not None:
121
+ np.random.seed(int(seed))
122
+
123
+ # Generate random latent points
124
+ latent_points = np.random.randn(1, latent_dim)
125
+
126
+ # Generate images
127
+ generated_images = model.predict(latent_points, verbose=0)
128
+
129
+ # Scale from [-1, 1] to [0, 255]
130
+ generated_images = ((generated_images + 1) / 2.0 * 255).astype(np.uint8)
131
+
132
+ # Single image
133
+ img = Image.fromarray(generated_images[0])
134
+
135
+ # Convert to bytes
136
+ buf = io.BytesIO()
137
+ img.save(buf, format='PNG')
138
+ buf.seek(0)
139
+
140
+ return send_file(buf, mimetype='image/png')
141
+
142
+ except Exception as e:
143
+ return jsonify({"error": str(e)}), 500
144
 
145
  if __name__ == "__main__":
146
+ app.run(host="0.0.0.0", port=8002, debug=False)
 
requirements.txt CHANGED
@@ -1,7 +1,6 @@
1
- fastapi[all]
2
- uvicorn[standard]
3
  tensorflow-cpu==2.15.0
4
  keras==2.15.0
5
  numpy==1.26.4
6
  pillow
7
- pydantic
 
1
+ flask
2
+ flask-cors
3
  tensorflow-cpu==2.15.0
4
  keras==2.15.0
5
  numpy==1.26.4
6
  pillow