"""FastAPI service that serves images from a pretrained conditional GAN generator.

The generator architecture is rebuilt in code and its weights are fetched with
``hf_hub_download`` once, when this module is imported.
"""

import io

import numpy as np
import tensorflow as tf
from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel
from tensorflow.keras import layers

app = FastAPI()

LEAKY_ALPHA = 0.2  # negative slope passed to every LeakyReLU in the generator
Z_DIM = 100  # noise-vector length; also the width of the label embedding
NUM_CLASSES = 32  # number of conditioning labels; should equal len(CLASSES)

# Class names indexed by label id: label i is rendered as CLASSES[i].
CLASSES = [
    'ain', 'al', 'aleff', 'bb', 'dal', 'dha', 'dhad', 'fa',
    'gaaf', 'ghain', 'ha', 'haa', 'jeem', 'kaaf', 'khaa', 'la',
    'laam', 'meem', 'nun', 'ra', 'saad', 'seen', 'sheen', 'ta',
    'taa', 'thaa', 'thal', 'toot', 'waw', 'ya', 'yaa', 'zay',
]


def build_generator() -> tf.keras.Model:
    """Build the (untrained) conditional generator network.

    The integer label is embedded into a ``Z_DIM`` vector and concatenated
    with the noise. A dense layer is reshaped to ``4x4x512`` and then upsampled
    by six stride-2 transposed convolutions (4 -> 256 pixels per side). A final
    single-filter convolution with tanh produces the image.

    NOTE: layer creation order must stay exactly as-is; the weights loaded
    below were saved from a model built in this order.

    Returns:
        A Keras model taking ``[noise, label]``. ``noise`` has shape
        ``(batch, Z_DIM)`` float32 and ``label`` has shape ``(batch, 1)`` int32.
        The output is a ``(batch, 256, 256, 1)`` float32 tensor in ``[-1, 1]``.
    """
    noise_input = tf.keras.Input(shape=(Z_DIM,), dtype='float32', name="noise")
    label_input = tf.keras.Input(shape=(1,), dtype='int32', name="label")

    label_emb = layers.Embedding(NUM_CLASSES, Z_DIM)(label_input)
    label_emb = layers.Flatten()(label_emb)

    x = layers.Concatenate()([noise_input, label_emb])
    x = layers.Dense(4 * 4 * 512, use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(LEAKY_ALPHA)(x)
    x = layers.Reshape((4, 4, 512))(x)

    # Each block doubles the spatial size: 4 -> 8 -> ... -> 256.
    for filters in [512, 256, 128, 64, 64, 32]:
        x = layers.Conv2DTranspose(
            filters, 4, strides=2, padding='same', use_bias=False
        )(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(LEAKY_ALPHA)(x)

    x = layers.Conv2D(1, 3, padding='same')(x)
    # Force float32 output even if a mixed-precision policy is active.
    output = layers.Activation('tanh', dtype='float32')(x)
    return tf.keras.Model([noise_input, label_input], output, name="generator")


# Load model once at startup
print("Loading generator weights...")
weights_path = hf_hub_download(
    repo_id="zainebrahem/my-cgan",
    filename="generator.weights.h5",
)
generator = build_generator()
generator.load_weights(weights_path)
print("Generator ready!")


class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``."""

    label: int  # class index, valid range is 0 .. NUM_CLASSES - 1


@app.get("/")
def root():
    """Health check; also returns the list of class names."""
    return {"status": "CGAN API is running", "classes": CLASSES}


@app.post("/generate")
def generate(req: GenerateRequest):
    """Generate one image for ``req.label`` and return it as a PNG.

    A fresh standard-normal noise vector is drawn on every call, so repeated
    requests for the same label return different images.

    Returns:
        ``image/png`` response on success. For an out-of-range label, returns
        an ``{"error": ...}`` JSON body with HTTP status 400, so clients can
        detect the failure from the status code instead of the content type.
    """
    if not 0 <= req.label < NUM_CLASSES:
        return JSONResponse(
            status_code=400,
            content={"error": f"Label must be between 0 and {NUM_CLASSES - 1}"},
        )

    noise = tf.random.normal([1, Z_DIM])
    label_tensor = tf.constant([[req.label]], dtype=tf.int32)
    image = generator([noise, label_tensor], training=False)

    # Map the tanh range [-1, 1] onto 8-bit pixel values [0, 255].
    pixels = (image[0].numpy() * 127.5 + 127.5).clip(0, 255).astype(np.uint8)

    # A 2-D uint8 array is inferred as 8-bit grayscale ("L"). The explicit
    # ``mode=`` argument is deprecated in recent Pillow releases.
    img_pil = Image.fromarray(pixels[:, :, 0])

    buf = io.BytesIO()
    img_pil.save(buf, format='PNG')
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")


@app.get("/classes")
def get_classes():
    """Return a mapping of label id to class name.

    JSON serialization turns the integer keys into strings.
    """
    return {i: c for i, c in enumerate(CLASSES)}