# S2Fnew / app.py
# Last change: "Update app.py" (commit 5de14f3) by kunalpro379.
import io
import os

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from PIL import Image
from tensorflow.keras.models import load_model
app = FastAPI(title="GAN Image Generator API")

# Origins allowed to call the API from a browser, as a comma-separated list,
# e.g. "https://app.example.com,https://staging.example.com".
# Defaults to "*" (any origin), which is convenient for development.
# NOTE: the CORS spec does not allow a literal "*" origin on credentialed
# responses, so list explicit origins in production if the frontend sends
# cookies or auth headers.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]  # An unset or blank variable keeps the permissive default.

# Add CORS middleware so the React frontend can connect.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the generator once at import time so every request reuses it.
# MODEL_PATH overrides the default model file.
model_path = os.getenv("MODEL_PATH", "generator_final.h5")
# compile=False: this service only runs inference, so the training
# configuration stored in the file (optimizer, loss) is not needed.
# Restoring that configuration can fail if the saved loss is a custom object
# that is not registered here.
generator = load_model(model_path, compile=False)
def preprocess_sketch(image_bytes, size=(256, 256)):
    """Decode an uploaded sketch into a batch the generator can consume.

    Args:
        image_bytes: Raw bytes of an image file in any format PIL can decode.
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``.
            Defaults to ``(256, 256)``, the input size this service was
            written for.

    Returns:
        A ``float32`` NumPy array of shape ``(1, height, width, 3)`` with
        pixel values scaled to ``[0, 1]``.

    Raises:
        ValueError: If the bytes cannot be decoded or resized as an image.
            The original exception is chained as ``__cause__``.
    """
    try:
        # convert("RGB") forces decoding and always yields exactly three
        # channels (grayscale, palette and RGBA inputs included), so no
        # separate channel-count fix-up is needed afterwards.
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        img = img.resize(size)
    except Exception as e:
        raise ValueError(f"Image processing failed: {e}") from e

    img_array = np.asarray(img, dtype=np.float32) / 255.0
    # Prepend the batch axis expected by Keras predict.
    return np.expand_dims(img_array, axis=0)
@app.get("/")
async def health_check():
    """Health check: confirm the API process is up and serving requests.

    Returns:
        JSONResponse: A fixed payload with a status message and the
        service's model name.
    """
    payload = {
        "status": "API is running",
        "model": "GAN Image Generator",
    }
    return JSONResponse(content=payload)
def _render_png(batch):
    """Run the generator on a preprocessed batch; return the first result as PNG bytes.

    Args:
        batch: Model input as produced by ``preprocess_sketch``.

    Returns:
        The first generated image, encoded as PNG.
    """
    # verbose=0 stops Keras from printing a progress bar on every request.
    generated = generator.predict(batch, verbose=0)
    # Assumes generator output lies in [0, 1], as the clipping implies.
    # Round rather than truncate, so e.g. 0.999 maps to 255 instead of 254.
    pixels = np.rint(np.clip(generated[0], 0, 1) * 255).astype(np.uint8)
    # NOTE(review): a single-channel output (H, W, 1) is squeezed to (H, W)
    # so PIL can build a grayscale image; confirm the model's output channels.
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    buffer = io.BytesIO()
    Image.fromarray(pixels).save(buffer, format="PNG")
    return buffer.getvalue()


@app.post("/generate-from-sketch")
async def generate_from_sketch(file: UploadFile = File(...)):
    """Turn an uploaded sketch into a GAN-generated PNG image.

    Args:
        file: The uploaded sketch image (multipart form field ``file``).

    Returns:
        Response: The generated image with media type ``image/png``.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image; 500 if inference or PNG encoding fails.
    """
    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    # Image decoding and model inference are synchronous and CPU-bound. Run
    # them in the threadpool so they don't stall the event loop for other
    # requests.
    # NOTE(review): this lets generator.predict run concurrently across
    # requests; confirm that is acceptable for the deployed TF/Keras version.
    try:
        batch = await run_in_threadpool(preprocess_sketch, contents)
    except ValueError as e:
        # Bad input from the client.
        raise HTTPException(status_code=400, detail=str(e)) from e

    try:
        png_bytes = await run_in_threadpool(_render_png, batch)
    except Exception as e:
        # The input was valid, so a failure here is a server-side error.
        raise HTTPException(
            status_code=500, detail=f"Image generation failed: {e}"
        ) from e

    # The PNG is already fully in memory; a plain Response sends it in one
    # piece with a Content-Length header instead of streaming it.
    return Response(content=png_bytes, media_type="image/png")