# my-cgan-api / app.py
# Uploaded by zainebrahem ("Upload 2 files"), commit a5e8c4b (verified).
import io

import numpy as np
import tensorflow as tf
from fastapi import FastAPI
from fastapi.responses import JSONResponse, Response, StreamingResponse
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel
from tensorflow.keras import layers
# FastAPI application object; the endpoints in this module are registered on it
# through the @app.get / @app.post decorators.
app = FastAPI()
# Negative slope used by every LeakyReLU activation in the generator.
LEAKY_ALPHA = 0.2
# Length of the input noise vector; also the width of the label embedding.
Z_DIM = 100
# Number of conditioning classes. Kept as an explicit literal because it fixes
# the size of the generator's Embedding layer, which must match the weights
# loaded from the checkpoint.
NUM_CLASSES = 32
# Display name for each class index: CLASSES[i] is the name of label i.
CLASSES = [
    'ain', 'al', 'aleff', 'bb', 'dal', 'dha', 'dhad', 'fa',
    'gaaf', 'ghain', 'ha', 'haa', 'jeem', 'kaaf', 'khaa', 'la',
    'laam', 'meem', 'nun', 'ra', 'saad', 'seen', 'sheen', 'ta',
    'taa', 'thaa', 'thal', 'toot', 'waw', 'ya', 'yaa', 'zay',
]

# The /generate endpoint bounds-checks labels against NUM_CLASSES while
# /classes enumerates CLASSES; fail fast at import time if they ever drift.
if len(CLASSES) != NUM_CLASSES:
    raise RuntimeError(
        f"CLASSES has {len(CLASSES)} entries but NUM_CLASSES is {NUM_CLASSES}"
    )
def build_generator():
    """Construct the conditional generator network (architecture only, no weights).

    Model inputs:
        noise: float32 tensor of shape (batch, Z_DIM).
        label: int32 tensor of shape (batch, 1) with class indices in
            [0, NUM_CLASSES).

    The label is embedded into a Z_DIM-length vector and concatenated with the
    noise. A bias-free dense projection to 4x4x512 follows, then six stride-2
    transposed convolutions (spatial 4 -> 256). A final 3x3 convolution and tanh
    give an output of shape (batch, 256, 256, 1) with values in [-1, 1].

    Weights are loaded into this model from a checkpoint, so the sequence and
    configuration of layers must not change.
    """

    def bn_lrelu(tensor):
        # Batch-normalize, then LeakyReLU: used after the projection and after
        # every transposed convolution.
        tensor = layers.BatchNormalization()(tensor)
        return layers.LeakyReLU(LEAKY_ALPHA)(tensor)

    noise_in = tf.keras.Input(shape=(Z_DIM,), dtype='float32', name="noise")
    label_in = tf.keras.Input(shape=(1,), dtype='int32', name="label")

    # Learned per-class vector: (batch, 1, Z_DIM) flattened to (batch, Z_DIM).
    label_vec = layers.Embedding(NUM_CLASSES, Z_DIM)(label_in)
    label_vec = layers.Flatten()(label_vec)

    hidden = layers.Concatenate()([noise_in, label_vec])
    hidden = layers.Dense(4 * 4 * 512, use_bias=False)(hidden)
    hidden = bn_lrelu(hidden)
    hidden = layers.Reshape((4, 4, 512))(hidden)

    # Each stage doubles the spatial resolution.
    for n_filters in (512, 256, 128, 64, 64, 32):
        hidden = layers.Conv2DTranspose(
            n_filters, 4, strides=2, padding='same', use_bias=False
        )(hidden)
        hidden = bn_lrelu(hidden)

    hidden = layers.Conv2D(1, 3, padding='same')(hidden)
    image_out = layers.Activation('tanh', dtype='float32')(hidden)
    return tf.keras.Model([noise_in, label_in], image_out, name="generator")
# Load the model once at startup. Fetch the trained generator weights from the
# Hugging Face Hub; hf_hub_download returns a local path to the downloaded file.
# Then restore them into a freshly built generator that every request reuses.
_WEIGHTS_REPO_ID = "zainebrahem/my-cgan"
_WEIGHTS_FILENAME = "generator.weights.h5"

print("Loading generator weights...")
weights_path = hf_hub_download(repo_id=_WEIGHTS_REPO_ID, filename=_WEIGHTS_FILENAME)
generator = build_generator()
generator.load_weights(weights_path)
print("Generator ready!")
class GenerateRequest(BaseModel):
    """JSON request body accepted by the POST /generate endpoint."""

    # Class index into CLASSES; the endpoint rejects values outside
    # 0..NUM_CLASSES-1 (i.e. 0 to 31).
    label: int
@app.get("/")
def root():
    """Health check: report that the service is up and list the class names."""
    payload = {"status": "CGAN API is running"}
    payload["classes"] = CLASSES
    return payload
@app.post("/generate")
def generate(req: GenerateRequest):
    """Generate one image for the requested class label and return it as a PNG.

    A new standard-normal noise vector of length Z_DIM is drawn on every call,
    so repeated requests for the same label produce different images.

    Args:
        req: Request body; ``req.label`` is the class index (see CLASSES).

    Returns:
        An ``image/png`` response containing the generated grayscale image, or,
        when the label is out of range, a 422 JSON response of the form
        ``{"error": "..."}``.
    """
    if not 0 <= req.label < NUM_CLASSES:
        # Keep the original {"error": ...} body, but signal the failure with a
        # 4xx status instead of 200 so HTTP clients can detect it. 422 matches
        # what FastAPI returns for other invalid request bodies.
        return JSONResponse(
            status_code=422,
            content={"error": f"Label must be between 0 and {NUM_CLASSES - 1}"},
        )

    noise = tf.random.normal([1, Z_DIM])
    label_tensor = tf.constant([[req.label]], dtype=tf.int32)
    image = generator([noise, label_tensor], training=False)

    # The generator ends in tanh, so outputs lie in [-1, 1]; map to [0, 255].
    pixels = (image[0].numpy() * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
    # Drop the single trailing channel. Pillow infers mode "L" for a 2-D uint8
    # array, so the explicit (deprecated in recent Pillow) mode= is unnecessary.
    img_pil = Image.fromarray(pixels[:, :, 0])

    buf = io.BytesIO()
    img_pil.save(buf, format="PNG")
    # The PNG is already fully in memory: send it as a single body (with a
    # Content-Length header) rather than as a chunked stream.
    return Response(content=buf.getvalue(), media_type="image/png")
@app.get("/classes")
def get_classes():
    """Return a mapping from class index to class name, e.g. {0: 'ain', ...}."""
    return dict(enumerate(CLASSES))