| | import gradio as gr |
| | import cv2 |
| | import numpy as np |
| | import torch |
| | from transformers import AutoModel |
| |
|
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): trust_remote_code=True executes Python code shipped with the
# Hub repository; consider pinning `revision=` to a reviewed commit.
model = AutoModel.from_pretrained("ianpan/chest-x-ray-basic", trust_remote_code=True)
# eval() switches the module to evaluation mode in place (affects layers such
# as dropout/batch-norm, if the model has any); .to() moves its weights.
model.eval()
model = model.to(device)
| |
|
| |
|
def calculate_ctr(mask):
    """Estimate the cardiothoracic ratio (CTR) from a 2-D label mask.

    Pixels labelled 1 or 2 are pooled as "lung" (thorax extent), pixels
    labelled 3 are "heart". Widths are the horizontal span between the
    leftmost and rightmost column occupied by each structure.

    Args:
        mask: 2-D integer label array (row, column).

    Returns:
        ``(ctr, thorax_left, thorax_right, heart_left, heart_right)``.
        All five are ``None`` when either structure is absent. When both are
        present the four column indices are ints; ``ctr`` is a float, or
        ``None`` if the lung span has zero width.
    """
    # Only column indices matter for horizontal widths; row indices are dropped.
    _, lung_cols = np.nonzero(np.isin(mask, (1, 2)))
    _, heart_cols = np.nonzero(mask == 3)

    if not (lung_cols.size and heart_cols.size):
        return None, None, None, None, None

    thorax_left, thorax_right = int(lung_cols.min()), int(lung_cols.max())
    heart_left, heart_right = int(heart_cols.min()), int(heart_cols.max())

    thorax_width = thorax_right - thorax_left
    # A zero-width thorax would divide by zero, so report "no CTR" instead.
    ctr = float((heart_right - heart_left) / thorax_width) if thorax_width else None

    return ctr, thorax_left, thorax_right, heart_left, heart_right
| |
|
| |
|
def _run_model(image):
    """Run the chest X-ray model on one image and derive CTR measurements.

    Shared by the "Viewer" and "JSON API" tabs so both report the same numbers.

    Args:
        image: PIL image; converted to 8-bit grayscale via ``convert("L")``.

    Returns:
        A 12-tuple, in this exact order::

            (img, mask, h, w, ctr,
             thorax_left, thorax_right, heart_left, heart_right,
             view_idx, age_pred, female_prob)

        - img: uint8 grayscale array of shape (h, w).
        - mask: uint8 label map resized to (h, w); per-pixel argmax over the
          model's "mask" channels (labels 1/2 are treated as lung and 3 as
          heart by ``calculate_ctr``).
        - ctr, thorax_left, thorax_right, heart_left, heart_right: output of
          ``calculate_ctr(mask)``; any of them may be ``None``.
        - view_idx: int argmax of the model's "view" output.
        - age_pred, female_prob: floats read from the "age" and "female"
          outputs (female_prob is presumably a probability -- the UI
          thresholds it at 0.5; confirm against the model card).
    """
    img = np.array(image.convert("L"))
    h, w = img.shape[:2]

    # model.preprocess is provided by the remote model code; assumed to return
    # a 2-D numpy array (TODO confirm). Add batch and channel axes so the
    # tensor becomes (1, 1, H', W').
    x = model.preprocess(img)
    x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0).float()

    with torch.inference_mode():
        out = model(x.to(device))

    # Collapse the class channels to one label per pixel for the single batch
    # item, then upsample to the original resolution. Nearest-neighbour keeps
    # labels integral; cv2.resize takes dsize as (width, height).
    mask_small = out["mask"].argmax(1)[0].cpu().numpy()
    mask = cv2.resize(mask_small.astype("uint8"), (w, h), interpolation=cv2.INTER_NEAREST)

    view_idx = out["view"].argmax(1).item()
    age_pred = float(out["age"].item())
    female_prob = float(out["female"].item())

    ctr, thorax_left, thorax_right, heart_left, heart_right = calculate_ctr(mask)

    return (
        img,
        mask,
        h,
        w,
        ctr,
        thorax_left,
        thorax_right,
        heart_left,
        heart_right,
        view_idx,
        age_pred,
        female_prob,
    )
| |
|
| |
|
| | |
| |
|
def analyze(image):
    """Build the Viewer tab output: a colour overlay and a text report.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing
            was uploaded.

    Returns:
        ``(blended, report)`` where ``blended`` is an RGB array (70% original,
        30% colour-coded mask: label 1 green, 2 blue, 3 red) and ``report`` is
        a newline-joined summary. Returns ``(None, "No image uploaded.")``
        when ``image`` is ``None``.
    """
    if image is None:
        return None, "No image uploaded."

    (img, mask, _h, _w, ctr,
     _thorax_l, _thorax_r, _heart_l, _heart_r,
     view_idx, age_pred, female_prob) = _run_model(image)

    base_rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    tinted = base_rgb.copy()
    for label, rgb in ((1, [0, 255, 0]), (2, [0, 128, 255]), (3, [255, 0, 0])):
        tinted[mask == label] = rgb
    blended = cv2.addWeighted(base_rgb, 0.7, tinted, 0.3, 0)

    view = {0: "AP", 1: "PA", 2: "lateral"}.get(view_idx, "unknown")

    if ctr is None:
        ctr_line = "CTR: could not be reliably calculated (segmentation issue)."
    else:
        ctr_line = f"CTR: {ctr:.2f}"
    sex = "Female" if female_prob >= 0.5 else "Male"

    report = [
        ctr_line,
        f"View (model): {view}",
        f"Predicted age: {age_pred:.0f} years",
        f"Predicted sex: {sex} (prob={female_prob:.2f})",
        "",
        "⚠️ Research/educational use only, NOT for clinical decision-making.",
    ]
    # CTR conventions assume a PA projection, so flag any other view.
    if view != "PA":
        report.append("⚠️ CTR is normally interpreted on PA view. Interpret with caution.")

    return blended, "\n".join(report)
| |
|
| |
|
# "Viewer" tab: uploads one image and shows the segmentation overlay plus a
# text report produced by analyze(). The input component is built before the
# output components, as before.
visual_demo = gr.Interface(
    title="AI CTR helper (research only)",
    description="Segments heart and lungs and estimates CTR using "
    "'ianpan/chest-x-ray-basic'. Research use only.",
    fn=analyze,
    inputs=gr.Image(label="Chest X-ray (PNG/JPG) – frontal view", type="pil"),
    outputs=[gr.Image(label="Segmentation overlay"), gr.Textbox(label="AI output")],
)
| |
|
| |
|
| | |
| |
|
def get_points(image):
    """Return the CTR measurements for one image as a JSON-friendly dict.

    Args:
        image: PIL image from the Gradio input, or ``None``.

    Returns:
        ``{"error": "No image uploaded"}`` when ``image`` is ``None``;
        otherwise a dict with the image size, ``ctr``, the thorax/heart
        x-coordinates in pixels (any of which may be ``None`` when the
        segmentation is missing), and the model's view index.
    """
    if image is None:
        return {"error": "No image uploaded"}

    (_img, _mask, height, width, ctr,
     thorax_l, thorax_r, heart_l, heart_r,
     view_idx, _age, _female) = _run_model(image)

    return dict(
        image_width=width,
        image_height=height,
        ctr=ctr,
        thorax_left_px=thorax_l,
        thorax_right_px=thorax_r,
        heart_left_px=heart_l,
        heart_right_px=heart_r,
        view_idx=int(view_idx),
    )
| |
|
| |
|
# "JSON API" tab, also exposed as the named endpoint "ctr_points"; returns the
# dict built by get_points(). Input is constructed before the output component.
points_api = gr.Interface(
    fn=get_points,
    api_name="ctr_points",
    title="CTR points API",
    description="Returns thorax/heart x-coordinates and CTR as JSON.",
    inputs=gr.Image(label="Chest X-ray (PNG/JPG) – frontal view", type="pil"),
    outputs=gr.JSON(label="CTR points JSON"),
)
| |
|
# Combine both interfaces into one app; tab names pair positionally with the
# interface list.
demo = gr.TabbedInterface(
    [visual_demo, points_api],
    tab_names=["Viewer", "JSON API"],
)
| |
|
# Launch the Gradio server only when this file is executed as a script;
# importing the module just builds `demo` without starting it.
if __name__ == "__main__":
    demo.launch()
| |
|