Spaces:
Sleeping
Sleeping
import io
import os
import threading

import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from PIL import Image
from tensorflow.keras.models import load_model
app = FastAPI(title="GAN Image Generator API")

# CORS lets the React frontend (served from another origin) call this API.
# CORS_ALLOW_ORIGINS is a comma-separated list of allowed origins, e.g.
# "https://frontend.example,http://localhost:3000". It defaults to "*" (any
# origin), which is convenient for development; narrow it for production.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]  # an empty/blank setting falls back to the permissive default
_allow_any_origin = "*" in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Wildcard origins must not be combined with credentials (the CORS spec
    # forbids it; Starlette would instead echo back every caller's Origin,
    # letting any website send credentialed requests). Credentials are only
    # allowed when the permitted origins are listed explicitly.
    allow_credentials=not _allow_any_origin,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Location of the saved Keras generator; the MODEL_PATH environment variable
# overrides the bundled default file name.
model_path = os.environ.get("MODEL_PATH", "generator_final.h5")

# Loaded once when the module is imported, so every request shares this model.
generator = load_model(model_path)
def preprocess_sketch(image_bytes: bytes, size=(256, 256)):
    """Decode an uploaded sketch into a batch ready for the generator.

    Args:
        image_bytes: Raw bytes of an image file in any format Pillow can open.
        size: ``(width, height)`` target passed to ``PIL.Image.resize``.
            Defaults to ``(256, 256)``, the input size this API has always
            fed the generator.

    Returns:
        A ``float32`` NumPy array of shape ``(1, height, width, 3)`` with
        pixel values scaled to ``[0, 1]``.

    Raises:
        ValueError: If the bytes cannot be decoded or processed; the
            underlying exception is chained as ``__cause__``.
    """
    try:
        # Release the decoder's buffer as soon as the pixels are copied out.
        with Image.open(io.BytesIO(image_bytes)) as source:
            # convert("RGB") always produces exactly 3 channels (grayscale,
            # palette and RGBA inputs included), so no channel fix-up is needed.
            rgb = source.convert("RGB")
        rgb = rgb.resize(size)
        pixels = np.asarray(rgb, dtype=np.float32) / 255.0
        # Prepend the batch axis: (H, W, 3) -> (1, H, W, 3).
        return pixels[np.newaxis, ...]
    except Exception as err:
        # Single boundary for "bad upload": callers map ValueError to HTTP 400.
        raise ValueError(f"Image processing failed: {err}") from err
async def health_check():
    """Health check: confirm the API is up and name the model it serves."""
    # NOTE(review): no route decorator (e.g. ``@app.get(...)``) is visible on
    # this function in this source; unless it is registered elsewhere, FastAPI
    # never exposes it as an endpoint.
    body = {"status": "API is running", "model": "GAN Image Generator"}
    return JSONResponse(content=body)
# Predictions run in worker threads (see generate_from_sketch); this lock lets
# only one thread use the shared Keras model at a time, so concurrent requests
# never call predict on it simultaneously.
_INFERENCE_LOCK = threading.Lock()


def _generate_png(batch):
    """Run the generator on a preprocessed batch; return the first image as PNG bytes.

    Blocking (model inference plus PNG encoding) — call it from a worker thread.
    """
    with _INFERENCE_LOCK:
        # verbose=0 stops Keras printing a progress bar on every request.
        prediction = generator.predict(batch, verbose=0)
    # NOTE(review): assumes the generator emits values in [0, 1], matching the
    # [0, 1] input scaling in preprocess_sketch; a tanh-style [-1, 1] output
    # would need rescaling rather than clipping — confirm against training code.
    pixels = np.clip(prediction[0], 0, 1) * 255
    # Image.fromarray rejects (H, W, 1) uint8 arrays; drop a singleton channel
    # axis so single-channel generators also produce a valid grayscale PNG.
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    image = Image.fromarray(pixels.astype(np.uint8))
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


async def generate_from_sketch(file: UploadFile = File(...)):
    """Turn an uploaded sketch into a generated image.

    The upload is decoded by ``preprocess_sketch`` and passed through the
    module-level ``generator``; the first output image is returned as a PNG.
    Decoding and inference run in the worker thread pool so the event loop
    stays free to serve other requests while the model works.

    Args:
        file: The uploaded sketch image (multipart form field ``file``).

    Returns:
        A ``StreamingResponse`` carrying the PNG with ``media_type="image/png"``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
            Other failures (e.g. inside the model) are not client errors and
            are left to propagate, so FastAPI responds with 500 and the server
            logs the traceback instead of leaking it to the client.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on
    # this function in this source; unless it is registered elsewhere, FastAPI
    # never exposes it as an endpoint.
    contents = await file.read()
    try:
        batch = await run_in_threadpool(preprocess_sketch, contents)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err)) from err
    png_bytes = await run_in_threadpool(_generate_png, batch)
    return StreamingResponse(io.BytesIO(png_bytes), media_type="image/png")