# deepl / script.py
# yusufbardolia's picture
# Upload 9 files
# 9946dd2 verified
import os
from collections import Counter

import pandas as pd
import timm
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.stats import mode
from torchvision import transforms
# CONFIG MUST MATCH TRAINING
# Architecture name handed to timm.create_model() when the model is rebuilt.
MODEL_NAME = 'efficientnet_b3'
# Target size given to transforms.Resize() for every input image.
IMAGE_SIZE = (300, 300)
# Number of output classes of the classification head.
NUM_CLASSES = 3

# Pick the fastest available backend: CUDA first, then Apple MPS, and CPU as
# the universal fallback. Hosts without MPS behave exactly as before unless a
# CUDA device is present.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
def apply_temporal_smoothing(predictions, window_size=5):
    """Smooth a sequence of class labels with a sliding majority vote.

    Every output element is the most frequent label found in a window
    centred on the same index of ``predictions``. The window reaches
    ``window_size // 2`` elements to each side and is truncated at the ends
    of the sequence. When several labels are tied for most frequent, the
    smallest one wins. Votes are always taken from the original
    ``predictions``, never from already-smoothed values.

    Args:
        predictions: Sequence of hashable, orderable labels. It must support
            ``len()``, slicing, ``.copy()`` and item assignment on the copy
            (a plain ``list`` does).
        window_size: Nominal window width. ``0`` and ``1`` leave the input
            unchanged.

    Returns:
        A copy of ``predictions`` holding the smoothed labels; the input is
        not modified.

    Raises:
        ValueError: If ``window_size`` is negative (every window would be
            empty, so no majority exists).
    """
    if window_size < 0:
        raise ValueError(f"window_size must be non-negative, got {window_size}")

    half_width = window_size // 2
    total = len(predictions)
    smoothed_preds = predictions.copy()

    for i in range(total):
        window = predictions[max(0, i - half_width):min(total, i + half_width + 1)]
        counts = Counter(window)
        top_count = max(counts.values())
        # Smallest label among the most frequent ones keeps ties deterministic.
        smoothed_preds[i] = min(
            label for label, count in counts.items() if count == top_count
        )

    return smoothed_preds
def _predict_with_tta(model, img_path, transform):
    """Return the predicted class index for one image using flip TTA.

    The image and its horizontal mirror are each pushed through ``model``;
    the two softmax distributions are averaged and the arg-max is returned
    as a Python ``int``.
    """
    # Open inside a context manager so the file handle is released right
    # away; convert() yields an independent in-memory RGB image.
    with Image.open(img_path) as img_file:
        img_pil = img_file.convert("RGB")

    img_tensor = transform(img_pil).unsqueeze(0).to(DEVICE)
    img_flip = transform(
        img_pil.transpose(Image.FLIP_LEFT_RIGHT)
    ).unsqueeze(0).to(DEVICE)

    out1 = model(img_tensor)
    out2 = model(img_flip)

    # Average probabilities (not logits) of the two views.
    avg_out = (F.softmax(out1, dim=1) + F.softmax(out2, dim=1)) / 2
    return torch.argmax(avg_out, dim=1).item()


def run_inference(TEST_IMAGE_PATH, model, SUBMISSION_CSV_SAVE_PATH):
    """Classify every entry of a directory and write a submission CSV.

    Directory entries are processed in sorted name order. Each one is
    predicted with horizontal-flip test-time augmentation, the resulting
    label sequence is smoothed with ``apply_temporal_smoothing`` (window of
    5), and the result is saved with columns ``file_name`` and
    ``category_id``.

    Args:
        TEST_IMAGE_PATH: Directory whose entries are treated as images.
        model: Classifier called on a ``(1, C, H, W)`` tensor placed on
            ``DEVICE``; it is switched to eval mode here.
        SUBMISSION_CSV_SAVE_PATH: Destination path of the CSV file.

    Returns:
        None. The CSV file is written as a side effect.
    """
    model.eval()
    test_images = sorted(os.listdir(TEST_IMAGE_PATH))

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        # Standard ImageNet channel statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    raw_predictions = []
    print(f"Inference with TTA on {len(test_images)} images...")

    with torch.no_grad():
        for img_name in test_images:
            img_path = os.path.join(TEST_IMAGE_PATH, img_name)
            try:
                pred = _predict_with_tta(model, img_path, transform)
            except Exception as e:
                # Best effort: every listed entry must still get a row, so
                # anything that fails (unreadable file, sub-directory, model
                # error) is reported and falls back to class 0.
                print(f"Error {img_name}: {e}")
                pred = 0
            raw_predictions.append(pred)

    final_predictions = apply_temporal_smoothing(raw_predictions, window_size=5)

    df = pd.DataFrame({"file_name": test_images, "category_id": final_predictions})
    df.to_csv(SUBMISSION_CSV_SAVE_PATH, index=False)
    print(f"Saved to {SUBMISSION_CSV_SAVE_PATH}")
if __name__ == "__main__":
    # The weights file and the output CSV both sit next to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    TEST_PATH = "/tmp/data/test_images"
    MODEL_PATH = os.path.join(script_dir, "multiclass_model.pth")
    SUB_PATH = os.path.join(script_dir, "submission.csv")

    # Rebuild the architecture without pretrained weights, then restore the
    # trained parameters from disk.
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints.
    model = timm.create_model(
        MODEL_NAME,
        pretrained=False,
        num_classes=NUM_CLASSES,
    )
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model = model.to(DEVICE)

    run_inference(
        TEST_IMAGE_PATH=TEST_PATH,
        model=model,
        SUBMISSION_CSV_SAVE_PATH=SUB_PATH,
    )