Spaces:
Sleeping
Sleeping
popboat1
Fix Tensor shape parsing bug, implement SVG icons, and migrate video streaming to robust HTTP polling
import base64
import math
import os

import cv2
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
# Enable on-demand GPU memory growth for every visible GPU. Looping over an
# empty device list is a no-op, so no separate emptiness check is needed.
gpus = tf.config.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # Best effort: report the problem and keep starting up.
    print(e)

app = FastAPI()

# CORS is wide open (any origin, method and header) with credentials disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=False,
)

# NOTE(review): this runs after TensorFlow has been imported and already
# queried for devices above; TF_CPP_MIN_LOG_LEVEL is normally expected to be
# set before `import tensorflow` to silence native logging -- confirm whether
# it has any effect at this point.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
class SafeDense(tf.keras.layers.Dense):
    """Dense layer whose constructor ignores a ``quantization_config`` keyword.

    Every other keyword argument is forwarded unchanged to
    ``tf.keras.layers.Dense``. Presumably this lets layer configs that carry
    the extra key be rebuilt without a constructor error -- confirm against
    the Keras version that produced the saved model.
    """

    def __init__(self, **kwargs):
        forwarded = {
            key: value
            for key, value in kwargs.items()
            if key != 'quantization_config'
        }
        super().__init__(**forwarded)
class SafeConv2D(tf.keras.layers.Conv2D):
    """Conv2D layer whose constructor ignores a ``quantization_config`` keyword.

    Every other keyword argument is forwarded unchanged to
    ``tf.keras.layers.Conv2D``. Presumably this lets layer configs that carry
    the extra key be rebuilt without a constructor error -- confirm against
    the Keras version that produced the saved model.
    """

    def __init__(self, **kwargs):
        forwarded = {
            key: value
            for key, value in kwargs.items()
            if key != 'quantization_config'
        }
        super().__init__(**forwarded)
# The weights file is looked up next to this module, not in the working
# directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, 'alexnet_cifar10_keras.h5')

# Load the saved model, substituting the tolerant layer subclasses for the
# stock 'Dense' and 'Conv2D' entries during deserialisation.
model = tf.keras.models.load_model(
    MODEL_PATH,
    custom_objects={'Dense': SafeDense, 'Conv2D': SafeConv2D},
)

# Second model sharing the same inputs that returns the output of every
# convolutional layer, in the order the layers appear in `model.layers`.
conv_layers = [
    candidate
    for candidate in model.layers
    if isinstance(candidate, tf.keras.layers.Conv2D)
]
feature_extractor = tf.keras.Model(
    inputs=model.inputs,
    outputs=[conv.output for conv in conv_layers],
)

# Human-readable label for each class index returned by the model.
CIFAR10_CLASSES = [
    'Airplane',
    'Automobile',
    'Bird',
    'Cat',
    'Deer',
    'Dog',
    'Frog',
    'Horse',
    'Ship',
    'Truck',
]
def generate_feature_grid(feature_map, max_features=64):
    """Tile the channels of an activation map into a base64-encoded RGBA PNG.

    Each of the first ``max_features`` channels is min-max scaled to 0-255 on
    its own and placed row-major into a near-square grid. The grid is coloured
    with OpenCV's VIRIDIS colormap, and the scaled intensity is reused as the
    alpha channel.

    Args:
        feature_map: Array-like of shape (height, width, channels), or a
            4-D array whose first axis is indexed at 0 before rendering.
            The caller's array is never modified.
        max_features: Upper bound on the number of channels drawn.

    Returns:
        The encoded PNG as a base64 string (without a ``data:`` URI prefix).

    Raises:
        ValueError: If the input is not 3-D after dropping the leading axis.
        RuntimeError: If OpenCV reports that PNG encoding failed.
    """
    feature_map = np.asarray(feature_map)
    if feature_map.ndim == 4:
        feature_map = feature_map[0]
    if feature_map.ndim != 3:
        raise ValueError(
            "feature_map must have shape (H, W, C) or (N, H, W, C); "
            f"got {feature_map.shape}"
        )

    height, width, channels = feature_map.shape
    num_features = min(channels, max_features)
    grid_size = math.ceil(math.sqrt(num_features))
    grid_image = np.zeros(
        (grid_size * height, grid_size * width), dtype=np.float32
    )

    # Floating inputs keep their dtype so results match the previous
    # arithmetic; anything else is promoted so the in-place division below
    # is legal.
    if np.issubdtype(feature_map.dtype, np.floating):
        work_dtype = feature_map.dtype
    else:
        work_dtype = np.float32

    for i in range(num_features):
        row, col = divmod(i, grid_size)

        # astype() copies, so the scaling below cannot leak into the
        # caller's array through a view.
        channel_img = feature_map[:, :, i].astype(work_dtype)
        channel_img -= channel_img.min()
        peak = channel_img.max()
        if peak > 0:
            channel_img /= peak
        channel_img *= 255.0

        y_start, y_end = row * height, (row + 1) * height
        x_start, x_end = col * width, (col + 1) * width
        grid_image[y_start:y_end, x_start:x_end] = channel_img

    grid_image = grid_image.astype(np.uint8)
    colored_grid = cv2.applyColorMap(grid_image, cv2.COLORMAP_VIRIDIS)
    b_channel, g_channel, r_channel = cv2.split(colored_grid)
    # The grayscale intensity doubles as the alpha channel.
    transparent_grid = cv2.merge(
        (b_channel, g_channel, r_channel, grid_image)
    )

    encoded_ok, buffer = cv2.imencode('.png', transparent_grid)
    if not encoded_ok:
        raise RuntimeError("cv2.imencode failed to encode the feature grid")
    return base64.b64encode(buffer).decode('utf-8')
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return per-layer activation grids.

    The upload is decoded with OpenCV, resized to 227x227, scaled to [0, 1]
    and passed through both ``feature_extractor`` and ``model``.

    NOTE(review): no route decorator is visible directly above this function
    in this view of the file -- confirm that it is registered on ``app``.
    NOTE(review): ``cv2.imdecode`` yields BGR channel order and the image is
    fed to the model as-is -- confirm this matches how the model was trained.

    Args:
        file: The uploaded image file.

    Returns:
        A dict with ``"prediction"`` (the class label) and ``"layers"``, a
        list with one entry per extracted activation holding its 1-based
        ``layer_index``, its ``shape`` without the batch axis, and a PNG data
        URI in ``texture_b64``.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image.
    """
    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    nparr = np.frombuffer(contents, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img is None:
        # imdecode reports an undecodable buffer by returning None; without
        # this guard the resize below fails with an unhandled cv2 error.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file could not be decoded as an image.",
        )

    img_resized = cv2.resize(img, (227, 227))
    img_normalized = img_resized.astype(np.float32) / 255.0
    img_batch = np.expand_dims(img_normalized, axis=0)

    activations = feature_extractor(img_batch, training=False)
    predictions = model(img_batch, training=False)
    # Plain int rather than a NumPy scalar for the list lookup below.
    class_idx = int(np.argmax(predictions[0].numpy()))

    layer_data = []
    for layer_index, activation in enumerate(activations, start=1):
        b64_image = generate_feature_grid(activation.numpy())
        clean_shape = [int(dim) for dim in activation.shape[1:]]
        layer_data.append({
            "layer_index": layer_index,
            "shape": clean_shape,
            "texture_b64": f"data:image/png;base64,{b64_image}",
        })

    return {
        "prediction": CIFAR10_CLASSES[class_idx],
        "layers": layer_data,
    }
# Serve the front-end from the "static" directory at the site root. The
# directory is given as a relative path, so it is resolved against the
# process working directory rather than this module's location.
_static_site = StaticFiles(directory="static", html=True)
app.mount("/", _static_site, name="static")