kunalpro379 commited on
Commit
476965d
·
verified ·
1 Parent(s): 79ec3e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -51
app.py CHANGED
@@ -1,51 +1,71 @@
1
- from fastapi import FastAPI, UploadFile, File
2
- from fastapi.responses import JSONResponse, StreamingResponse
3
- import numpy as np
4
- from tensorflow.keras.models import load_model
5
- from PIL import Image
6
- import io
7
- import os
8
-
9
- app = FastAPI(title="GAN Image Generator API")
10
-
11
- # Load model
12
- model_path = os.getenv("MODEL_PATH", "generator_final.h5")
13
- generator = load_model(model_path)
14
-
15
- def generate_image(noise_dim=100):
16
- """Generate image from random noise"""
17
- z = np.random.randn(1, noise_dim, 1, 1).astype(np.float32)
18
- fake_image = generator.predict(z)
19
- fake_image = (fake_image.squeeze() * 255).astype(np.uint8)
20
- return fake_image[..., :3]
21
-
22
- @app.get("/")
23
- def read_root():
24
- return {"message": "GAN Image Generator API"}
25
-
26
- @app.get("/generate-random")
27
- async def generate_random_image():
28
- """Endpoint to generate random image"""
29
- image_array = generate_image()
30
- img = Image.fromarray(image_array)
31
-
32
- # Convert to bytes
33
- img_byte_arr = io.BytesIO()
34
- img.save(img_byte_arr, format='PNG')
35
- img_byte_arr.seek(0)
36
-
37
- return StreamingResponse(img_byte_arr, media_type="image/png")
38
-
39
- @app.post("/generate-from-sketch")
40
- async def generate_from_sketch(file: UploadFile = File(...)):
41
- """Endpoint to generate from sketch"""
42
- # Process your sketch here (add your sketch processing logic)
43
- # For now just returns a random image
44
- image_array = generate_image()
45
- img = Image.fromarray(image_array)
46
-
47
- img_byte_arr = io.BytesIO()
48
- img.save(img_byte_arr, format='PNG')
49
- img_byte_arr.seek(0)
50
-
51
- return StreamingResponse(img_byte_arr, media_type="image/png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import StreamingResponse
4
+ import numpy as np
5
+ from tensorflow.keras.models import load_model
6
+ from PIL import Image
7
+ import io
8
+ import os
9
+ import cv2
10
+
11
+ app = FastAPI(title="GAN Image Generator API")
12
+
13
+ # Add CORS middleware to allow React frontend to connect
14
+ app.add_middleware(
15
+ CORSMiddleware,
16
+ allow_origins=["*"], # Allows all origins - adjust for production!
17
+ allow_credentials=True,
18
+ allow_methods=["*"],
19
+ allow_headers=["*"],
20
+ )
21
+
22
+ # Load model
23
+ model_path = os.getenv("MODEL_PATH", "generator_final.h5")
24
+ generator = load_model(model_path)
25
+
26
+ def preprocess_sketch(image_bytes):
27
+ """Process uploaded sketch image for the GAN"""
28
+ try:
29
+ # Convert bytes to PIL Image
30
+ img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
31
+ img = np.array(img)
32
+
33
+ # Convert to grayscale
34
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
35
+
36
+ # Resize to model's expected input size (adjust dimensions as needed)
37
+ resized = cv2.resize(gray, (256, 256))
38
+
39
+ # Normalize pixel values to [-1, 1] range
40
+ normalized = (resized.astype(np.float32) / 127.5) - 1.0
41
+
42
+ # Add batch and channel dimensions
43
+ processed = np.expand_dims(normalized, axis=[0, -1])
44
+
45
+ return processed
46
+ except Exception as e:
47
+ raise ValueError(f"Image processing failed: {str(e)}")
48
+
49
+ @app.post("/generate-from-sketch")
50
+ async def generate_from_sketch(file: UploadFile = File(...)):
51
+ try:
52
+ # Read uploaded file
53
+ contents = await file.read()
54
+
55
+ # Process sketch
56
+ processed_sketch = preprocess_sketch(contents)
57
+
58
+ # Generate image (modify this to use your actual GAN prediction)
59
+ generated = generator.predict(processed_sketch)
60
+ generated = (generated.squeeze() * 255).astype(np.uint8)
61
+
62
+ # Convert to bytes
63
+ img = Image.fromarray(generated)
64
+ img_byte_arr = io.BytesIO()
65
+ img.save(img_byte_arr, format='PNG')
66
+ img_byte_arr.seek(0)
67
+
68
+ return StreamingResponse(img_byte_arr, media_type="image/png")
69
+
70
+ except Exception as e:
71
+ raise HTTPException(status_code=400, detail=str(e))