# Hugging Face Space (status: Sleeping)
import io

import numpy as np
import tensorflow as tf
from fastapi import FastAPI
from fastapi.responses import JSONResponse, Response, StreamingResponse
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel
from tensorflow.keras import layers
# The FastAPI application instance for this service.
app = FastAPI()
# Negative-slope coefficient passed to every LeakyReLU in the generator.
LEAKY_ALPHA = 0.2
# Length of the random noise vector fed to the generator.
Z_DIM = 100

# Class names indexed by integer label (CLASSES[label] is the name of `label`).
# The names appear to be transliterated Arabic letter names; the order
# presumably must match the label encoding used during training -- TODO
# confirm before reordering or editing this list.
CLASSES = [
    'ain', 'al', 'aleff', 'bb', 'dal', 'dha', 'dhad', 'fa',
    'gaaf', 'ghain', 'ha', 'haa', 'jeem', 'kaaf', 'khaa', 'la',
    'laam', 'meem', 'nun', 'ra', 'saad', 'seen', 'sheen', 'ta',
    'taa', 'thaa', 'thal', 'toot', 'waw', 'ya', 'yaa', 'zay',
]
# Derived from CLASSES so the label range and the name list can never drift
# apart. It also sizes the generator's label Embedding, so it must equal the
# class count the saved weights were trained with (currently 32); a mismatch
# should surface as a shape error when the weights are loaded.
NUM_CLASSES = len(CLASSES)
def build_generator():
    """Build the conditional-GAN generator network (architecture only, untrained).

    The architecture must stay identical to the one used in training: the
    weights later loaded from ``generator.weights.h5`` are matched to these
    layers, so changing layer types, order, or sizes would break loading.

    Model inputs:
        noise: float32 tensor of shape (batch, Z_DIM).
        label: int32 tensor of shape (batch, 1) with class indices in
            [0, NUM_CLASSES).

    Returns:
        tf.keras.Model: maps ``[noise, label]`` to a (batch, 256, 256, 1)
        float32 image with values in [-1, 1] (tanh output).
    """
    noise_input = tf.keras.Input(shape=(Z_DIM,), dtype='float32', name="noise")
    label_input = tf.keras.Input(shape=(1,), dtype='int32', name="label")
    # Class conditioning: embed the label as a Z_DIM vector, flatten
    # (batch, 1, Z_DIM) -> (batch, Z_DIM), and concatenate it with the noise.
    label_emb = layers.Embedding(NUM_CLASSES, Z_DIM)(label_input)
    label_emb = layers.Flatten()(label_emb)
    x = layers.Concatenate()([noise_input, label_emb])
    # Project to a 4x4x512 seed feature map. use_bias=False because the
    # following BatchNormalization's learned offset makes a bias redundant.
    x = layers.Dense(4 * 4 * 512, use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(LEAKY_ALPHA)(x)
    x = layers.Reshape((4, 4, 512))(x)
    # Six stride-2 'same' transposed convolutions each double the spatial
    # size: 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256.
    # NOTE(review): indentation was lost in the scraped source; BN + LeakyReLU
    # are assumed to run inside the loop -- confirm against the trained weights.
    for filters in [512, 256, 128, 64, 64, 32]:
        x = layers.Conv2DTranspose(filters, 4, strides=2, padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(LEAKY_ALPHA)(x)
    # Collapse to one channel; tanh bounds pixel values to [-1, 1]. The explicit
    # float32 dtype keeps the output float32 even under a mixed-precision
    # policy (presumably how the model was trained -- TODO confirm).
    x = layers.Conv2D(1, 3, padding='same')(x)
    output = layers.Activation('tanh', dtype='float32')(x)
    return tf.keras.Model([noise_input, label_input], output, name="generator")
# One-time startup work, executed when this module is imported: fetch the
# trained generator weights from the Hugging Face Hub, build the network, and
# load the weights into it. Request handlers then reuse this single in-memory
# `generator` instead of rebuilding it per call.
print("Loading generator weights...")
weights_path = hf_hub_download(
    filename="generator.weights.h5",
    repo_id="zainebrahem/my-cgan",
)
generator = build_generator()
generator.load_weights(weights_path)
print("Generator ready!")
class GenerateRequest(BaseModel):
    """Input model for ``generate``: the class index to condition the image on."""

    label: int  # 0 to 31 -- an index into CLASSES; the range is checked in generate, not here
@app.get("/")
def root():
    """Report that the API is up and list the available class names.

    Without the route decorator this function was never registered on
    ``app``, so the service exposed no root endpoint.

    Returns:
        dict: ``{"status": ..., "classes": CLASSES}``, where the position of a
        name in ``classes`` is its integer label.
    """
    return {"status": "CGAN API is running", "classes": CLASSES}
@app.post("/generate")
def generate(req: GenerateRequest):
    """Generate one image for the requested class and return it as a PNG.

    A fresh standard-normal noise vector is drawn on every call, so repeated
    requests for the same label yield different images.

    Args:
        req: Request whose ``label`` must be a class index in
            ``[0, NUM_CLASSES)``.

    Returns:
        Response: the generated single-channel image encoded as ``image/png``.
        When the label is out of range, a JSONResponse with HTTP status 422
        and body ``{"error": ...}`` is returned instead.
    """
    if not 0 <= req.label < NUM_CLASSES:
        # Same error body as before, but with a client-error status so callers
        # cannot mistake it for a successful (HTTP 200) response.
        return JSONResponse(
            status_code=422,
            content={"error": f"Label must be between 0 and {NUM_CLASSES - 1}"},
        )

    noise = tf.random.normal([1, Z_DIM])
    label_tensor = tf.constant([[req.label]], dtype=tf.int32)
    image = generator([noise, label_tensor], training=False)

    # The generator ends in tanh, so values lie in [-1, 1]; map them linearly
    # onto the 0..255 uint8 pixel range.
    pixels = (image[0].numpy() * 127.5 + 127.5).clip(0, 255).astype(np.uint8)

    # A 2-D uint8 array is inferred as an 8-bit grayscale ("L") image; the
    # explicit mode= argument to fromarray is deprecated in recent Pillow.
    img_pil = Image.fromarray(pixels[:, :, 0])

    buf = io.BytesIO()
    img_pil.save(buf, format='PNG')
    # The PNG is already fully in memory: send it as a single body with a
    # Content-Length instead of streaming the BytesIO, which would be
    # iterated in newline-delimited chunks.
    return Response(content=buf.getvalue(), media_type="image/png")
@app.get("/classes")
def get_classes():
    """Return the mapping from integer label to class name.

    Without the route decorator this function was never registered on
    ``app``, so the endpoint did not exist.

    Returns:
        dict[int, str]: e.g. ``{0: 'ain', 1: 'al', ...}``. In the JSON
        response the integer keys are serialized as strings.
    """
    return dict(enumerate(CLASSES))