Spaces:
Sleeping
Sleeping
import functools
import os
import random
import sys

import albumentations as A
import cv2
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2

# Add current directory to path so the local `classification` package resolves
sys.path.append(os.getcwd())
from classification.model import CliniScanClassifier  # noqa: E402  (needs the sys.path tweak above)
def get_inference_transforms():
    """Build the albumentations pipeline applied to images at inference time.

    The pipeline runs three steps in order:

    1. Resize to 256x256.
    2. Normalise each channel with the standard ImageNet mean/std values.
    3. Convert the numpy image to a torch tensor via ``ToTensorV2``.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    pipeline_steps = [
        A.Resize(256, 256),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(pipeline_steps)
# Output labels in the order of the classifier head's logits (standard VinDr-CXR classes).
CLASS_NAMES = (
    "Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity",
    "Nodule/Mass", "Other lesion", "Pleural effusion", "Pleural thickening",
    "Pneumothorax", "Pulmonary fibrosis", "No finding",
)


@functools.lru_cache(maxsize=4)
def _load_model(model_path, device):
    """Build the classifier, load weights from *model_path* onto *device*, and set eval mode.

    Cached per (model_path, device) so repeated predictions reuse one model
    instead of re-reading the checkpoint for every image.
    """
    model = CliniScanClassifier(len(CLASS_NAMES))
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # Consider weights_only=True if the minimum supported torch version allows it.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def _load_image_tensor(image_path, device):
    """Read *image_path*, convert BGR->RGB, apply inference transforms, add a batch dim, move to *device*."""
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    transformed = get_inference_transforms()(image=image)
    return transformed['image'].unsqueeze(0).to(device)


def predict(image_path, model_path, device, threshold=0.2):
    """Run the multi-label classifier on one image and return the findings above *threshold*.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.
        model_path: Path to a saved ``state_dict`` for ``CliniScanClassifier``.
        device: torch device (or device string) to run inference on.
        threshold: Minimum sigmoid probability for a class to be reported.
            Defaults to 0.2, a 20% threshold for medical findings.

    Returns:
        List of ``(class_name, probability)`` tuples sorted by probability,
        highest first. The list is empty when no class exceeds the threshold.

    Raises:
        FileNotFoundError: If OpenCV cannot read *image_path*.
    """
    # 1. Load the model (cached across calls).
    model = _load_model(model_path, device)

    # 2. Load and preprocess the image.
    image_tensor = _load_image_tensor(image_path, device)

    # 3. Inference. Labels are independent, so use a sigmoid per class rather than a softmax.
    with torch.no_grad():
        logits = model(image_tensor)
    probs = torch.sigmoid(logits).cpu().numpy()[0]

    # 4. Keep the classes above the threshold, most confident first.
    results = [(name, prob) for name, prob in zip(CLASS_NAMES, probs) if prob > threshold]
    results.sort(key=lambda item: item[1], reverse=True)
    return results
# Default locations used when this module is run as a script.
# Updated path to the models/ folder
MODEL_PATH = 'models/best_resnet_classification.pth'
IMAGE_DIR = 'data/images'
RESULTS_DIR = 'results'


def _draw_predictions(image, predictions):
    """Draw the "AI DIAGNOSIS" header and up to the first three predictions onto *image* in place."""
    y_offset = 30
    cv2.putText(image, "AI DIAGNOSIS", (10, y_offset),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
    y_offset += 30
    if not predictions:
        cv2.putText(image, "No abnormalities detected", (10, y_offset),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
        return
    for name, score in predictions[:3]:  # Top 3
        text = f"{name}: {score:.1%}"
        cv2.putText(image, text, (10, y_offset),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
        y_offset += 25


def main():
    """Annotate up to three sampled PNGs from IMAGE_DIR and write them to RESULTS_DIR."""
    os.makedirs(RESULTS_DIR, exist_ok=True)

    if not os.path.isdir(IMAGE_DIR):
        print(f"Image directory not found: {IMAGE_DIR}")
        return

    # Sort first: os.listdir order is filesystem-dependent, so without it the
    # seeded random.sample below would not pick the same files on every machine.
    img_list = sorted(f for f in os.listdir(IMAGE_DIR) if f.endswith('.png'))
    if not img_list:
        print("No images found in data/images!")
        return

    random.seed(42)
    samples = random.sample(img_list, min(3, len(img_list)))
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    print("--- CliniScan AI Proof-of-Work Generator ---")

    for i, img_name in enumerate(samples, start=1):
        image_path = os.path.join(IMAGE_DIR, img_name)
        predictions = predict(image_path, MODEL_PATH, device)

        # Reload the original (BGR) image so the overlay is drawn on unmodified pixels.
        orig_img = cv2.imread(image_path)
        _draw_predictions(orig_img, predictions)

        save_path = f"{RESULTS_DIR}/evidence_{i}.png"
        # cv2.imwrite reports failure via its boolean return value, not an exception.
        if not cv2.imwrite(save_path, orig_img):
            print(f"Failed to write evidence image: {save_path}")
            continue
        print(f"Generated Evidence: {save_path}")


if __name__ == '__main__':
    main()