"""FastAPI service that turns an uploaded sketch into a GAN-generated image."""

import io
import logging
import os
import threading

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from PIL import Image
from tensorflow.keras.models import load_model

logger = logging.getLogger(__name__)

# (width, height) the uploaded sketch is resized to before inference.
SKETCH_SIZE = (256, 256)

app = FastAPI(title="GAN Image Generator API")

# Comma-separated list of origins allowed to call the API (e.g. the React
# frontend). Defaults to "*" (any origin) - restrict this in production.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

# Add CORS middleware to allow the React frontend to connect.
# NOTE(review): a wildcard origin combined with allow_credentials=True
# effectively lets any site make credentialed requests (Starlette echoes the
# caller's Origin in that case) - set CORS_ALLOW_ORIGINS explicitly in
# production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the generator once at import time; override the path with MODEL_PATH.
model_path = os.getenv("MODEL_PATH", "generator_final.h5")
generator = load_model(model_path)

# Inference used to be serialised implicitly because it ran on the event loop.
# It now runs in a worker thread, so this lock keeps model calls one at a time
# while leaving the event loop free for other requests.
_inference_lock = threading.Lock()


def preprocess_sketch(image_bytes):
    """Decode an uploaded sketch into a model-ready batch.

    Args:
        image_bytes: Raw bytes of an image file in any format Pillow can read.

    Returns:
        A float32 array of shape (1, 256, 256, 3) with values in [0, 1].

    Raises:
        ValueError: If the bytes cannot be decoded, converted or resized.
    """
    try:
        with Image.open(io.BytesIO(image_bytes)) as img:
            # convert("RGB") forces a full decode and always yields exactly
            # 3 channels, so no grayscale special-case is needed afterwards.
            rgb = img.convert("RGB").resize(SKETCH_SIZE)
        img_array = np.asarray(rgb, dtype=np.float32) / 255.0
        # Add the batch dimension: (1, H, W, 3).
        return np.expand_dims(img_array, axis=0)
    except Exception as e:
        raise ValueError(f"Image processing failed: {e}") from e


def _to_png_bytes(image):
    """Encode one generated image array as PNG bytes.

    Values are clipped to [0, 1] and scaled to 0-255 (truncating).
    NOTE(review): assumes the generator emits values in [0, 1]; a tanh output
    in [-1, 1] would need rescaling first - confirm against the model.
    """
    pixels = (np.clip(image, 0, 1) * 255).astype(np.uint8)
    # Pillow cannot build an image from an (H, W, 1) array; drop that axis.
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    buffer = io.BytesIO()
    Image.fromarray(pixels).save(buffer, format="PNG")
    return buffer.getvalue()


def _generate_png(batch):
    """Run the generator on a preprocessed batch and return the first result as PNG bytes.

    Blocking; intended to be called from the threadpool, not the event loop.
    """
    with _inference_lock:
        # verbose=0 stops Keras printing a progress bar for every request.
        output = generator.predict(batch, verbose=0)
    return _to_png_bytes(output[0])


@app.get("/")
async def health_check():
    """Health check endpoint"""
    return JSONResponse(content={"status": "API is running", "model": "GAN Image Generator"})


@app.post("/generate-from-sketch")
async def generate_from_sketch(file: UploadFile = File(...)):
    """Generate an image from an uploaded sketch and return it as a PNG.

    Responds 400 when the upload is empty or is not a decodable image, and
    500 (with a generic message; details are logged) when inference or
    encoding fails.
    """
    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    # Image decoding, inference and encoding are blocking/CPU-bound; running
    # them in the threadpool keeps the event loop responsive.
    try:
        processed_sketch = await run_in_threadpool(preprocess_sketch, contents)
    except ValueError as e:
        # Bad client input: report the decoding problem to the caller.
        raise HTTPException(status_code=400, detail=str(e)) from e

    try:
        png_bytes = await run_in_threadpool(_generate_png, processed_sketch)
    except Exception as e:
        # Server-side failure: log the details, don't leak them to the client.
        logger.exception("Image generation failed")
        raise HTTPException(status_code=500, detail="Image generation failed") from e

    # The whole PNG is already in memory, so a plain Response (which also sets
    # Content-Length) is simpler than streaming it.
    return Response(content=png_bytes, media_type="image/png")