import os
import sys
import io
import base64

import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Add the app directory to sys.path so the Space can import the local src package
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

from omegaconf import OmegaConf
from src.model import build_model
from src.attention_viz import attention_rollout_full, make_overlay
from src.dataset import QUESTION_GROUPS
from torchvision import transforms
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = None
cfg = None
transform = None
def load_model():
    global model, cfg, transform

    print("Loading configuration...")
    base_cfg = OmegaConf.load(os.path.join(current_dir, "configs/base.yaml"))
    # Overlay the full training config on the base config when it is present
    try:
        exp_cfg = OmegaConf.load(os.path.join(current_dir, "configs/full_train.yaml"))
        cfg = OmegaConf.merge(base_cfg, exp_cfg)
    except FileNotFoundError:
        cfg = base_cfg

    print("Building model...")
    model = build_model(cfg).to(device)

    ckpt_path = os.path.join(current_dir, "best_full_train.pt")
    if os.path.exists(ckpt_path):
        print(f"Loading checkpoint from {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model_state"])
    else:
        print(f"WARNING: Checkpoint not found at {ckpt_path}")
    model.eval()

    # Galaxy Zoo image transform: resize, center-crop, convert, normalize.
    # Standard ImageNet statistics for a 224x224 ViT input are assumed.
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
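
# Load the model once at import time so the first request does not pay the
# startup cost. (Assumption: the original app invoked load_model() somewhere
# not shown here; a FastAPI startup hook would serve equally well.)
load_model()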
@app.post("/predict")  # route path assumed; adjust to match the frontend
async def predict(file: UploadFile = File(...)):
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")

    # Transform the image and run the forward pass
    img_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Autocast only when running on CUDA; it is disabled on CPU
        with torch.amp.autocast(device_type=device.type, enabled=device.type == "cuda"):
            logits = model(img_tensor)

    # Attention weights cached by the forward pass
    layers = model.get_all_attention_weights()

    # In proper evaluation, a softmax is applied per question group,
    # mirroring the Galaxy Zoo decision-tree structure.
    probs = logits.detach().float().cpu().clone()  # float() in case autocast produced fp16
    for q_name, (start, end) in QUESTION_GROUPS.items():
        probs[:, start:end] = F.softmax(probs[:, start:end], dim=-1)

    probs_np = probs[0].numpy()
    results = {}
    for q_name, (start, end) in QUESTION_GROUPS.items():
        results[q_name] = probs_np[start:end].tolist()

    # Generate the attention heatmap overlay
    if layers is not None:
        # attention_rollout_full expects a list of [B, H, N+1, N+1] tensors
        all_layer_attns = [l.cpu() for l in layers]
        rollout_map = attention_rollout_full(all_layer_attns, patch_size=16, image_size=224)[0]
        # Original image as a numpy array at the model's input resolution
        original_img_np = np.array(image.resize((224, 224)))
        overlay = make_overlay(original_img_np, rollout_map, alpha=0.5, colormap="inferno")
        # Encode the overlay to a base64 PNG for the JSON response
        overlay_img = Image.fromarray(overlay)
        buffered = io.BytesIO()
        overlay_img.save(buffered, format="PNG")
        heatmap_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    else:
        heatmap_base64 = None

    return {
        "predictions": results,
        "heatmap": heatmap_base64,
    }
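
# Usage sketch (assumptions: this file is the Space's app.py and the client
# posts multipart form data; 7860 is the default Hugging Face Spaces port):
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -X POST -F "file=@galaxy.jpg" http://localhost:7860/predict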