Spaces:
Sleeping
Sleeping
| import uvicorn | |
| from fastapi import FastAPI, File, UploadFile, Form | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import tensorflow as tf | |
| from tensorflow.keras import layers, models, applications, Input, regularizers | |
| import numpy as np | |
| import joblib | |
| import cv2 | |
| import base64 | |
| from PIL import Image | |
| import io | |
| import json | |
# --- 0. SETUP ---
# ASGI application object; it is served by uvicorn from the __main__ guard.
app = FastAPI()

# Accept cross-origin requests from any site, with any method and any header,
# so a browser front-end hosted elsewhere can call this API directly.
# NOTE(review): wildcard CORS is only appropriate for a public demo; restrict
# allow_origins if the service ever handles authenticated or sensitive traffic.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# --- 1. DEFINE ARCHITECTURE (Exact Match to Training) ---
# NOTE(review): the trained weights are restored onto the graph built here, so
# keep the layer construction order below unchanged. Keras auto-names unnamed
# layers (dense, dense_1, ...) in creation order, and those names presumably
# have to line up with the ones stored in the saved weights file.
def cbam_block(x, ratio=8):
    """Apply CBAM-style attention: channel attention first, then spatial attention.

    Args:
        x: Symbolic 4-D Keras tensor. Channels-last layout is assumed (the code
            reads ``x.shape[-1]`` and reshapes to ``(1, 1, channel)``), and the
            channel dimension must be statically known.
        ratio: Reduction factor for the hidden layer of the shared channel MLP,
            which has ``channel // ratio`` units.

    Returns:
        A tensor shaped like ``x``, rescaled by the channel-attention weights
        and then by the spatial-attention map.
    """
    channel = x.shape[-1]
    # 1. Channel Attention
    # The same two Dense layer objects (l1 -> l2) are applied to both the
    # average- and max-pooled descriptors, so the two paths share one MLP.
    l1 = layers.Dense(channel // ratio, activation="relu", use_bias=False)
    l2 = layers.Dense(channel, use_bias=False)
    x_avg = l2(l1(layers.GlobalAveragePooling2D()(x)))
    x_max = l2(l1(layers.GlobalMaxPooling2D()(x)))
    x_att = layers.Activation('sigmoid')(layers.Add()([x_avg, x_max]))
    # Reshape the per-channel weights to (1, 1, C) so they scale every spatial position.
    x_att = layers.Reshape((1, 1, channel))(x_att)
    x = layers.Multiply()([x, x_att])
    # 2. Spatial Attention (FIXED: Uses Lambda to match training shapes)
    # Each Lambda reduces the channel axis to size 1, giving (H, W, 1).
    avg_pool = layers.Lambda(lambda t: tf.reduce_mean(t, axis=-1, keepdims=True))(x)
    max_pool = layers.Lambda(lambda t: tf.reduce_max(t, axis=-1, keepdims=True))(x)
    concat = layers.Concatenate(axis=-1)([avg_pool, max_pool]) # Shape (H, W, 2)
    # A 7x7 conv over the 2-channel map produces one sigmoid attention value
    # per spatial position.
    conv = layers.Conv2D(1, 7, padding='same', activation='sigmoid', use_bias=False)(concat)
    return layers.Multiply()([x, conv])
class TransformerBlock(layers.Layer):
    """Post-norm Transformer encoder block.

    Self-attention and a two-layer feed-forward network are each followed by
    dropout, a residual connection and LayerNormalization.

    Args:
        embed_dim: Width of the last input axis. Also used as the attention
            ``key_dim`` and as the feed-forward output width, so the residual
            additions line up.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        rate: Dropout rate applied to both sub-layer outputs.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, embed_dim=64, num_heads=4, ff_dim=128, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        # NOTE(review): keep these attribute names and this creation order.
        # Both presumably determine how saved weights map onto the sub-layers.
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = models.Sequential([layers.Dense(ff_dim, "relu"), layers.Dense(embed_dim)])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training=None):
        """Run the block on a sequence tensor whose last axis is ``embed_dim``.

        ``training`` defaults to None (not True) so that Keras resolves it from
        the calling context: True under ``fit``, inference otherwise. A True
        default would keep dropout active whenever the layer is invoked without
        an explicit flag, e.g. when a sub-model is called directly for
        Grad-CAM. That would make inference outputs and gradients random.
        """
        # Self-attention sub-layer (query = value = inputs) + residual + norm.
        out1 = self.layernorm1(inputs + self.dropout1(self.att(inputs, inputs), training=training))
        # Feed-forward sub-layer + residual + norm.
        return self.layernorm2(out1 + self.dropout2(self.ffn(out1), training=training))
def build_model_local():
    """Rebuild the two-branch (image + tabular) classifier architecture.

    The returned model is not compiled; this file only uses it to receive
    weights via ``load_weights`` and to run inference.

    Inputs:
        image_input: (batch, 224, 224, 3) images.
        tabular_input: (batch, 14) feature vectors.
    Output:
        diagnosis: (batch, 1) sigmoid score.

    NOTE(review): keep the layer construction order unchanged. Keras
    auto-names unnamed layers (dense, dense_1, ...) in creation order, and
    those names presumably must match the ones in the saved weights file.
    """
    # Visual Branch
    img_in = Input(shape=(224, 224, 3), name='image_input')
    # NOTE(review): weights='imagenet' downloads pretrained weights at startup,
    # which load_weights() presumably overwrites right afterwards. weights=None
    # would skip the download if the saved file covers the backbone; confirm first.
    base = applications.EfficientNetB0(include_top=False, weights='imagenet', input_tensor=img_in)
    # Freeze all but the last 20 backbone layers, mirroring the training setup.
    # Presumably irrelevant for inference-only use.
    for layer in base.layers[:-20]: layer.trainable = False
    x = cbam_block(base.output)
    x = layers.GlobalAveragePooling2D()(x)
    img_vec = layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l2(0.0001))(x)
    # Tabular Branch
    # 14 = number of features that predict() assembles (A1-A10, Age, Sex, Jaundice, FamHx).
    input_dim = 14
    tab_in = Input(shape=(input_dim,), name='tabular_input')
    # Project the feature vector to input_dim * 64 values, then view them as a
    # sequence of input_dim tokens of width 64 for the Transformer block.
    x = layers.Dense(input_dim * 64)(tab_in)
    x = layers.Reshape((input_dim, 64))(x)
    x = TransformerBlock(embed_dim=64, num_heads=4, ff_dim=128, rate=0.3)(x)
    x = layers.GlobalAveragePooling1D()(x)
    tab_vec = layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l2(0.0001))(x)
    # Fusion
    # Concatenate the two 128-d embeddings and classify with one sigmoid unit.
    combined = layers.Concatenate()([img_vec, tab_vec])
    z = layers.Dense(64, activation='relu')(combined)
    z = layers.Dropout(0.4)(z)
    out = layers.Dense(1, activation='sigmoid', name='diagnosis')(z)
    model = models.Model(inputs=[img_in, tab_in], outputs=out)
    return model
# --- 2. LOAD ASSETS ---
def _build_and_load_model(weights_path="autism_model.keras"):
    """Build the network and restore its trained weights.

    Propagates any exception raised by ``build_model_local`` or
    ``load_weights``. On failure the half-initialised network is simply
    discarded, so callers never receive an architecture with untrained weights.
    """
    net = build_model_local()
    # Now the shapes match (7,7,2,1) -> (7,7,2,1)
    net.load_weights(weights_path)
    return net


print("⏳ Loading Assets...")
# Each global stays None unless its asset loads successfully; predict()
# refuses to serve while either one is None.
model = None
scaler = None
try:
    # A. Scaler
    # joblib.load unpickles the file, so it must only ever point at a trusted artifact.
    scaler = joblib.load("scaler.pkl")
    print(" ✅ Scaler Loaded.")
    # B. Model: assigned only after its weights have loaded, never half-initialised.
    model = _build_and_load_model()
    print(" ✅ Model Weights Loaded.")
except Exception as e:
    # Top-level boundary: keep the server up so predict() can report the failure.
    print(f"\n❌ CRITICAL ERROR: {e}\n")
# --- 3. HELPER FUNCTIONS ---
def generate_gradcam(img_array, tab_array=None):
    """Compute a Grad-CAM heatmap for the image branch of the global ``model``.

    Args:
        img_array: Batch containing one preprocessed image, shape
            (1, 224, 224, 3). Only the first batch element is used for the map.
        tab_array: Optional (1, 14) tabular batch fed alongside the image.
            Defaults to zeros. Pass the same scaled features used for the
            prediction so the heatmap explains that prediction, not one made
            for a placeholder patient.

    Returns:
        A 2-D float array with values in [0, 1], at the spatial resolution of
        the last 4-D layer (it is not resized). A 224x224 array of zeros is
        returned when the model is unavailable or has no 4-D layer, and the
        map is all zeros when no gradient reaches that layer.
    """
    if model is None:
        return np.zeros((224, 224))

    # Robust layer detection: the last layer (in model.layers order) whose
    # output is 4-D, i.e. still a spatial feature map.
    target_output = None
    for layer in reversed(model.layers):
        try:
            candidate = layer.output
            rank = len(candidate.shape)
        except Exception:
            # Deliberately broad but not bare: probing is best-effort, and some
            # layers have no single well-defined output.
            continue
        if rank == 4:
            target_output = candidate
            break
    if target_output is None:
        return np.zeros((224, 224))

    grad_model = tf.keras.models.Model(
        inputs=model.inputs,
        outputs=[target_output, model.output],
    )

    img_tensor = tf.cast(img_array, tf.float32)
    if tab_array is None:
        tab_tensor = tf.zeros((1, 14), dtype=tf.float32)
    else:
        tab_tensor = tf.cast(tab_array, tf.float32)

    with tf.GradientTape() as tape:
        conv_out, preds = grad_model([img_tensor, tab_tensor])
        loss = preds[:, 0]
    grads = tape.gradient(loss, conv_out)
    if grads is None:
        return np.zeros((224, 224))

    # One importance weight per channel: the gradients averaged over batch and space.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    # Channel-weighted sum of the first image's feature maps.
    heatmap = tf.squeeze(conv_out[0] @ pooled_grads[..., tf.newaxis])
    # Keep positive evidence only, then scale to [0, 1]. divide_no_nan yields 0
    # instead of NaN when the map is all zeros (e.g. vanishing gradients from a
    # saturated sigmoid).
    heatmap = tf.maximum(heatmap, 0)
    heatmap = tf.math.divide_no_nan(heatmap, tf.reduce_max(heatmap))
    return heatmap.numpy()
# Feature order fed to the scaler and model. NOTE(review): this must presumably
# match the column order used at training time; verify against the training pipeline.
_FEATURE_KEYS = (
    "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9", "A10",
    "Age", "Sex", "Jaundice", "FamHx",
)
# Risk score strictly above this value is reported as "Autistic".
_DIAGNOSIS_THRESHOLD = 0.40


def _parse_patient_features(patient_data):
    """Decode the ``patient_data`` JSON object into the ordered feature list.

    Raises:
        ValueError: Invalid JSON (json.JSONDecodeError) or a non-numeric value.
        KeyError: A required key is missing.
        TypeError: The payload is not a JSON object, or a value is null.
    """
    data = json.loads(patient_data)
    # float() accepts numbers and numeric strings alike, so the scaler always
    # receives a numeric array rather than a string array.
    return [float(data[key]) for key in _FEATURE_KEYS]


def _render_gradcam_overlay(image, heatmap):
    """Blend a [0, 1] heatmap over a PIL RGB ``image`` and return a JPEG data URL.

    Raises:
        ValueError: If OpenCV fails to encode the overlay as JPEG.
    """
    # Sanitise first: NaN or out-of-range values would wrap around in the uint8 cast.
    heatmap = np.clip(np.nan_to_num(heatmap), 0.0, 1.0)
    jet = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    # Upsample the coarse heatmap to the image size (PIL size is (width, height),
    # which is also the order cv2.resize expects).
    jet = cv2.resize(jet, image.size)
    original_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    superimposed = cv2.addWeighted(original_cv, 0.6, jet, 0.4, 0)
    ok, buffer = cv2.imencode('.jpg', superimposed)
    if not ok:
        raise ValueError("JPEG encoding of the Grad-CAM overlay failed")
    xai_b64 = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/jpeg;base64,{xai_b64}"


# NOTE(review): without a route decorator this handler was never registered and
# the app exposed no endpoint. Confirm "/predict" matches the path the client calls.
@app.post("/predict")
async def predict(file: UploadFile = File(...), patient_data: str = Form(...)):
    """Score one patient from an image plus questionnaire data.

    Form fields:
        file: Image upload in any format PIL can open. It is converted to RGB
            and resized to 224x224.
        patient_data: JSON object with numeric values for the keys in
            ``_FEATURE_KEYS``.

    Returns:
        ``{"risk_score", "diagnosis", "xai_image"}`` on success, or
        ``{"error": <message>}`` when initialisation failed or the input is
        malformed.
    """
    if model is None or scaler is None:
        return {"error": "Server initialization failed."}

    # Process Image
    img_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB").resize((224, 224))
    except OSError as exc:  # PIL's UnidentifiedImageError is an OSError subclass
        return {"error": f"Invalid image: {exc}"}
    # NOTE(review): pixels are scaled to [0, 1] here. Keras' EfficientNetB0 is
    # documented to expect [0, 255] inputs; confirm training used the same /255 step.
    img_input = np.expand_dims(np.array(image) / 255.0, axis=0)

    # Process Tabular
    try:
        features = _parse_patient_features(patient_data)
    except (ValueError, KeyError, TypeError) as exc:
        return {"error": f"Invalid patient_data: {exc!r}"}
    tab_input = scaler.transform(np.array([features]))

    # Predict
    # NOTE(review): inference and Grad-CAM run synchronously inside this
    # coroutine, so they block the event loop for the duration of the request.
    prediction = model.predict([img_input, tab_input])
    risk_score = float(prediction[0][0])

    # XAI
    # generate_gradcam is given only the image, so its heatmap is computed with
    # its own placeholder tabular input rather than this patient's tab_input.
    heatmap = generate_gradcam(img_input)
    xai_image = _render_gradcam_overlay(image, heatmap)

    return {
        "risk_score": risk_score,
        "diagnosis": "Autistic" if risk_score > _DIAGNOSIS_THRESHOLD else "Non-Autistic",
        "xai_image": xai_image,
    }
if __name__ == "__main__":
    # Serve the API directly when this file is run as a script.
    # Hugging Face Spaces requires port 7860!
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)