"""Multiclass image inference with flip TTA, temporal smoothing and CSV export."""

import os
import re

import pandas as pd
import timm
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.stats import mode
from torchvision import transforms

# CONFIG MUST MATCH TRAINING
MODEL_NAME = 'efficientnet_b3'
IMAGE_SIZE = (300, 300)
NUM_CLASSES = 3

# Prefer an accelerator when present; the math is identical on every device.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Label recorded for images that cannot be read or preprocessed.
FALLBACK_CLASS = 0


def apply_temporal_smoothing(predictions, window_size=5):
    """Replace each label with the most common label in a centred window.

    Args:
        predictions: Class labels in temporal order (a list or 1-D array;
            anything supporting ``.copy()``, ``len()`` and slicing).
        window_size: Width of the window. It extends ``window_size // 2``
            items on each side of the current position, so an even value
            behaves like ``window_size + 1``. Windows are truncated at both
            ends of the sequence. Values <= 1 disable smoothing.

    Returns:
        A smoothed copy of ``predictions``; the input is not modified.
        Ties are resolved by ``scipy.stats.mode``, which returns the
        smallest of the tied labels.
    """
    smoothed = predictions.copy()
    if window_size <= 1:
        # Window of one item (or a degenerate size) is the identity.
        return smoothed

    half = window_size // 2
    n = len(predictions)
    for i in range(n):
        window = predictions[max(0, i - half):min(n, i + half + 1)]
        smoothed[i] = mode(window, keepdims=False)[0]
    return smoothed


def _natural_sort_key(name):
    """Sort key that orders embedded integers numerically ('f2' < 'f10')."""
    # With a capturing group, re.split alternates text / digit-run chunks,
    # so odd positions are always digit runs.
    parts = re.split(r"(\d+)", name)
    chunks = [int(part) if i % 2 else part for i, part in enumerate(parts)]
    # Fall back to the raw name so 'f01' vs 'f1' still sorts deterministically.
    return chunks, name


def _list_images(directory):
    """Return visible regular files in *directory*, in natural (frame) order."""
    # Hidden files (e.g. macOS .DS_Store) and sub-directories are not frames
    # and would otherwise become rows in the submission.
    names = [
        name
        for name in os.listdir(directory)
        if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
    ]
    return sorted(names, key=_natural_sort_key)


def _load_rgb(path):
    """Open an image file, convert it to RGB and close the file handle."""
    with Image.open(path) as img:
        # convert() returns a new, fully loaded image, so it stays valid
        # after the source file is closed.
        return img.convert("RGB")


def _predict_tta(model, img_tensor):
    """Return the class index from averaged original + horizontal-flip softmax.

    ``img_tensor`` is a single preprocessed (C, H, W) image already on the
    model's device.
    """
    # Flipping the width axis of the normalised tensor is equivalent to
    # flipping the PIL image before Resize/ToTensor/Normalize: those steps
    # are mirror-symmetric. One batched forward pass serves both views.
    batch = torch.stack([img_tensor, torch.flip(img_tensor, dims=[-1])])
    avg_probs = F.softmax(model(batch), dim=1).mean(dim=0)
    return int(torch.argmax(avg_probs).item())


def run_inference(TEST_IMAGE_PATH, model, SUBMISSION_CSV_SAVE_PATH, smoothing_window=5):
    """Predict a class for every image in a directory and write a CSV.

    Images are processed in natural filename order, which is assumed to be
    temporal order, since the predictions are then smoothed over neighbours.

    Args:
        TEST_IMAGE_PATH: Directory containing the test images.
        model: Classification model, already moved to ``DEVICE``.
        SUBMISSION_CSV_SAVE_PATH: Output CSV path, with columns
            ``file_name`` and ``category_id``.
        smoothing_window: Window size passed to ``apply_temporal_smoothing``.

    Images that cannot be read or preprocessed are reported and given
    ``FALLBACK_CLASS`` before smoothing. Errors raised by the model itself
    propagate instead of silently producing a bogus submission.
    """
    model.eval()
    test_images = _list_images(TEST_IMAGE_PATH)

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    raw_predictions = []
    print(f"Inference with TTA on {len(test_images)} images...")
    with torch.no_grad():
        for img_name in test_images:
            img_path = os.path.join(TEST_IMAGE_PATH, img_name)
            try:
                img_tensor = transform(_load_rgb(img_path))
            except (OSError, ValueError) as e:
                # Unreadable/corrupt image (PIL.UnidentifiedImageError is an
                # OSError): keep going so every file still gets a row.
                print(f"Error {img_name}: {e}")
                raw_predictions.append(FALLBACK_CLASS)
                continue
            raw_predictions.append(_predict_tta(model, img_tensor.to(DEVICE)))

    final_predictions = apply_temporal_smoothing(raw_predictions, window_size=smoothing_window)
    df = pd.DataFrame({"file_name": test_images, "category_id": final_predictions})
    df.to_csv(SUBMISSION_CSV_SAVE_PATH, index=False)
    print(f"Saved to {SUBMISSION_CSV_SAVE_PATH}")


if __name__ == "__main__":
    current_dir = os.path.dirname(os.path.abspath(__file__))
    TEST_PATH = "/tmp/data/test_images"
    MODEL_PATH = os.path.join(current_dir, "multiclass_model.pth")
    SUB_PATH = os.path.join(current_dir, "submission.csv")

    model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)
    # The checkpoint is a plain state_dict of tensors; weights_only=True
    # avoids executing arbitrary pickled code from the file.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(DEVICE)
    run_inference(TEST_PATH, model, SUB_PATH)