| import os |
| import torch |
| import pandas as pd |
| import timm |
| from PIL import Image |
| from torchvision import transforms |
| from scipy.stats import mode |
| import torch.nn.functional as F |
|
|
| |
# Architecture name handed to timm.create_model when the network is built.
MODEL_NAME = 'efficientnet_b3'
# Target size passed to transforms.Resize before an image is fed to the model.
IMAGE_SIZE = (300, 300)
# Number of output classes requested from timm.create_model.
NUM_CLASSES = 3

# Pick the fastest backend that is actually available: CUDA first, then
# Apple's MPS, and finally the CPU so the script still runs everywhere.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
|
|
def apply_temporal_smoothing(predictions, window_size=5):
    """Smooth a sequence of labels with a sliding majority vote.

    Every element is replaced by the most frequent value among the raw
    (unsmoothed) predictions inside a window centred on it. The window is
    truncated at both ends of the sequence, so elements near the edges vote
    over fewer neighbours. When several labels are equally frequent the
    smallest one wins.

    Args:
        predictions: Sequence of hashable, orderable labels that supports
            ``len``, slicing, item assignment and ``.copy()`` (for example a
            list of ints).
        window_size: Nominal window width; ``window_size // 2`` neighbours
            are considered on each side of the current element.

    Returns:
        A copy of ``predictions`` holding the smoothed labels. The input is
        not modified and the element type is preserved.

    Raises:
        ValueError: If ``window_size`` is negative.
    """
    # Local import so the block does not depend on a new file-level import.
    from collections import Counter

    if window_size < 0:
        raise ValueError(f"window_size must be non-negative, got {window_size}")

    half = window_size // 2
    total = len(predictions)
    smoothed_preds = predictions.copy()
    for idx in range(total):
        # Always vote over the raw predictions, never over already-smoothed
        # values, so the result does not depend on processing order.
        window = predictions[max(0, idx - half):min(total, idx + half + 1)]
        counts = Counter(window)
        top_count = max(counts.values())
        # Deterministic tie-break: the smallest label among the most frequent.
        smoothed_preds[idx] = min(
            label for label, count in counts.items() if count == top_count
        )
    return smoothed_preds
|
|
def run_inference(TEST_IMAGE_PATH, model, SUBMISSION_CSV_SAVE_PATH):
    """Classify every file in a directory and write the results to a CSV.

    Each file is opened as an RGB image, resized to ``IMAGE_SIZE`` and
    normalised. The model is evaluated on the image and on its horizontally
    flipped copy; the two softmax outputs are averaged and the arg-max is the
    raw prediction. The raw predictions, ordered by sorted file name, are
    then passed through ``apply_temporal_smoothing`` before being saved.

    Args:
        TEST_IMAGE_PATH: Directory whose entries are treated as test images.
        model: Classification network; it is switched to eval mode here.
        SUBMISSION_CSV_SAVE_PATH: Destination of the CSV, written with the
            columns ``file_name`` and ``category_id``.

    Notes:
        Any entry that cannot be processed is reported on stdout and gets
        class ``0``, so the CSV always has one row per directory entry.
    """
    model.eval()

    # Feed inputs to the device the model actually lives on; fall back to the
    # module-level default when the model exposes no parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device(DEVICE)

    # NOTE(review): smoothing runs over this order, so file names presumably
    # sort into the intended sequence - confirm against the dataset.
    test_images = sorted(os.listdir(TEST_IMAGE_PATH))

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    raw_predictions = []
    print(f"Inference with TTA on {len(test_images)} images...")

    with torch.no_grad():
        for img_name in test_images:
            img_path = os.path.join(TEST_IMAGE_PATH, img_name)
            try:
                # The context manager closes the file handle right away;
                # convert() returns an independent, fully loaded image.
                with Image.open(img_path) as raw_img:
                    img_pil = raw_img.convert("RGB")

                img_tensor = transform(img_pil).unsqueeze(0).to(device)
                img_flip = transform(
                    img_pil.transpose(Image.FLIP_LEFT_RIGHT)
                ).unsqueeze(0).to(device)

                out1 = model(img_tensor)
                out2 = model(img_flip)

                # Test-time augmentation: average the class probabilities of
                # the original and the mirrored view.
                avg_out = (F.softmax(out1, dim=1) + F.softmax(out2, dim=1)) / 2
                pred = torch.argmax(avg_out, dim=1).item()
            except Exception as e:
                # Deliberate best effort: keep one row per file and default
                # to class 0 instead of aborting the whole run.
                print(f"Error {img_name}: {e}")
                pred = 0
            raw_predictions.append(pred)

    final_predictions = apply_temporal_smoothing(raw_predictions, window_size=5)

    df = pd.DataFrame({"file_name": test_images, "category_id": final_predictions})
    df.to_csv(SUBMISSION_CSV_SAVE_PATH, index=False)
    print(f"Saved to {SUBMISSION_CSV_SAVE_PATH}")
|
|
if __name__ == "__main__":
    # The checkpoint is read from, and the submission written to, the
    # directory that contains this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    TEST_PATH = "/tmp/data/test_images"
    MODEL_PATH = os.path.join(script_dir, "multiclass_model.pth")
    SUB_PATH = os.path.join(script_dir, "submission.csv")

    # Build the bare architecture; the weights come from the local checkpoint.
    model = timm.create_model(
        MODEL_NAME,
        pretrained=False,
        num_classes=NUM_CLASSES,
    )
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # from a trusted source.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model = model.to(DEVICE)

    run_inference(
        TEST_IMAGE_PATH=TEST_PATH,
        model=model,
        SUBMISSION_CSV_SAVE_PATH=SUB_PATH,
    )