File size: 2,424 Bytes
476965d
 
5de14f3
476965d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b73c294
476965d
 
b73c294
 
 
 
 
 
 
 
 
 
 
 
476965d
 
 
 
5de14f3
 
 
 
 
476965d
 
 
 
 
b73c294
476965d
 
b73c294
 
476965d
b73c294
 
 
476965d
 
 
 
 
b73c294
476965d
b73c294
476965d
b73c294
 
5de14f3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import io
import logging
import os

import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from PIL import Image
from tensorflow.keras.models import load_model

app = FastAPI(title="GAN Image Generator API")

# CORS lets the browser-based React frontend call this API.
# ALLOWED_ORIGINS is an optional comma-separated list of origins, e.g.
# "http://localhost:3000,https://app.example.com". When it is unset (or "*")
# every origin is accepted, as before.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # A wildcard origin must not be combined with credentialed requests
    # (cookies / Authorization headers); only enable credentials when the
    # deployment names its origins explicitly.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the generator once at import time; MODEL_PATH overrides the default file.
model_path = os.getenv("MODEL_PATH", "generator_final.h5")
generator = load_model(model_path)

def preprocess_sketch(image_bytes: bytes, target_size: tuple = (256, 256)) -> "np.ndarray":
    """Decode an uploaded sketch into a normalized, batched model input.

    Args:
        image_bytes: Raw bytes of the uploaded image file.
        target_size: ``(width, height)`` the image is resized to. Defaults to
            256x256, the size that was previously hard-coded.

    Returns:
        A ``float32`` array of shape ``(1, height, width, 3)`` whose values
        are scaled to the range [0, 1].

    Raises:
        ValueError: If ``image_bytes`` is empty, or if decoding, resizing or
            conversion fails for any reason (the original error is chained).
    """
    if not image_bytes:
        # Fail early with a readable message instead of PIL's opaque one.
        raise ValueError("Image processing failed: empty image data")

    try:
        # convert() returns an independent RGB copy, so the source handle
        # can be closed as soon as the conversion is done.
        with Image.open(io.BytesIO(image_bytes)) as source:
            rgb = source.convert('RGB')

        rgb = rgb.resize(target_size)

        # An RGB image always yields an (H, W, 3) array, so no grayscale
        # channel-stacking fallback is needed here.
        pixels = np.array(rgb).astype(np.float32) / 255.0

        # Add the leading batch dimension.
        return np.expand_dims(pixels, axis=0)
    except Exception as e:
        # Single boundary: callers only need to handle ValueError.
        raise ValueError(f"Image processing failed: {str(e)}") from e

@app.get("/")
async def health_check():
    """Report that the service is up and name the model it serves."""
    payload = {"status": "API is running", "model": "GAN Image Generator"}
    return JSONResponse(content=payload)

@app.post("/generate-from-sketch")
async def generate_from_sketch(file: UploadFile = File(...)):
    """Turn an uploaded sketch into a GAN-generated PNG image.

    Args:
        file: Multipart upload containing the sketch image.

    Returns:
        A ``StreamingResponse`` carrying the generated image as ``image/png``.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be processed
            as an image; 500 when inference or PNG encoding fails.
    """
    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    # Bad input is a client error; preprocess_sketch reports it as ValueError.
    try:
        processed_sketch = preprocess_sketch(contents)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    try:
        # predict() is blocking compute; run it in a worker thread so the
        # event loop can keep serving other requests. verbose=0 suppresses
        # the per-call progress bar in the server logs.
        generated = await run_in_threadpool(
            generator.predict, processed_sketch, verbose=0
        )

        # Drop the batch dimension and map [0, 1] floats to 8-bit pixels.
        generated = np.clip(generated[0], 0, 1) * 255
        generated = generated.astype(np.uint8)

        # Encode as PNG into an in-memory buffer.
        img = Image.fromarray(generated)
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)
    except Exception as e:
        # Anything failing past this point is a server-side problem: log the
        # details and return a generic 500 rather than leaking internals.
        logging.getLogger(__name__).exception("Image generation failed")
        raise HTTPException(status_code=500, detail="Image generation failed") from e

    return StreamingResponse(img_byte_arr, media_type="image/png")