# ================================================================
# 🔍 INFERENCE — Load an image from a URL and correct its rotation
# ================================================================
import requests
import torch
from io import BytesIO
from PIL import Image
from torchvision import transforms
from transformers import ResNetForImageClassification
import matplotlib.pyplot as plt

MODEL_DIR = "/kaggle/working/rotation_model"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Class index -> detected angle in degrees.
# NOTE(review): must match the label order used when the model was trained — confirm.
ANGLES = [0, 90, 180, 270]

# ── Load model ──
model = ResNetForImageClassification.from_pretrained(MODEL_DIR).to(DEVICE).eval()

# Resize the short side to 256, center-crop 224x224, convert to a tensor and
# normalize with the ImageNet mean/std.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Right-angle corrections mapped to PIL's lossless transpose operations
# (PIL's ROTATE_* transposes rotate counter-clockwise).
_TRANSPOSE_FOR_CORRECTION = {
    90: Image.ROTATE_90,
    180: Image.ROTATE_180,
    270: Image.ROTATE_270,
}


def predict_rotation(pil_img: Image.Image) -> dict:
    """Classify which of ANGLES *pil_img* is rotated by.

    Returns a dict with:
        "detected":   the predicted angle from ANGLES,
        "correction": the angle that undoes it, (360 - detected) % 360,
        "probs":      {"<angle>°": "<probability formatted to 4 decimals>"}.
    """
    tensor = preprocess(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
    # Mixed precision only on CUDA: torch.cuda.amp.autocast() is deprecated and,
    # on a CPU-only machine, warns and disables itself anyway.
    use_amp = DEVICE.type == "cuda"
    with torch.no_grad(), torch.autocast(device_type=DEVICE.type, enabled=use_amp):
        logits = model(pixel_values=tensor).logits
    # Softmax in float32 so reduced-precision autocast logits don't degrade the
    # reported probabilities.
    probs = torch.softmax(logits.float(), dim=1)[0].cpu()
    pred = int(probs.argmax().item())
    detected = ANGLES[pred]
    correction = (360 - detected) % 360
    return {
        "detected": detected,
        "correction": correction,
        "probs": {f"{a}°": f"{probs[i]:.4f}" for i, a in enumerate(ANGLES)},
    }


def correct_image(pil_img: Image.Image, correction: int) -> Image.Image:
    """Return a new image rotated counter-clockwise by *correction* degrees.

    The angle is normalized modulo 360 (so -90 behaves like 270). Only right
    angles rotate; 0 or any non-right angle returns an unrotated copy.
    """
    method = _TRANSPOSE_FOR_CORRECTION.get(correction % 360)
    if method is None:
        return pil_img.copy()
    return pil_img.transpose(method)


def load_url(url: str, timeout: float = 15) -> Image.Image:
    """Download *url* and decode it as an RGB PIL image.

    Raises:
        requests.HTTPError: on a 4xx/5xx response. Without this check, an error page
            would be handed to PIL and fail with an opaque decode error.
        requests.Timeout: if the request takes longer than *timeout* seconds.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


# ═══════════════════════════════════════════
# Demo: correct a rotated image fetched from a URL
# ═══════════════════════════════════════════
def fix_image_from_url(url: str) -> Image.Image:
    """Download *url*, predict and undo its rotation, plot input vs. corrected,
    and return the corrected image."""
    img = load_url(url)
    result = predict_rotation(img)
    corrected = correct_image(img, result["correction"])

    print(f"📐 Recognized: {result['detected']}° | Correction: {result['correction']}°")
    print(f"📊 Probs: {result['probs']}")

    _, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, shown, title in zip(axes, (img, corrected), ("Input", "Corrected")):
        ax.imshow(shown)
        ax.set_title(title)
        ax.axis("off")
    plt.tight_layout()
    plt.show()
    return corrected


if __name__ == "__main__":
    # A notebook cell also runs with __name__ == "__main__", so the demo still runs there.
    corrected = fix_image_from_url("https://lh-tech.de/pexels-ana-ibarra-2152867215-32441547.jpg")