| |
| |
| |
|
|
| import requests, torch |
| from io import BytesIO |
| from PIL import Image |
| from torchvision import transforms |
| from transformers import ResNetForImageClassification |
| import matplotlib.pyplot as plt |
|
|
# Directory the classifier is loaded from via `from_pretrained`.
MODEL_DIR = "/kaggle/working/rotation_model"
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Candidate orientations in degrees; predict_rotation maps logit index i to
# ANGLES[i], so this order must match the model's class order.
ANGLES = list(range(0, 360, 90))
|
|
| |
# Load the fine-tuned ResNet classifier, move it to DEVICE and switch it to
# evaluation mode (dropout off, batch-norm uses running statistics).
model = ResNetForImageClassification.from_pretrained(MODEL_DIR)
model = model.to(DEVICE)
model = model.eval()
|
|
# Image -> normalized tensor pipeline: scale the shorter side to 256 px,
# take the central 224x224 crop, convert to a float tensor in [0, 1] and
# normalize per channel with the ImageNet mean/std.
# NOTE(review): presumably mirrors the preprocessing used when training the
# rotation model — confirm, since a mismatch silently hurts accuracy.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
def predict_rotation(pil_img: Image.Image) -> dict:
    """Predict which of the ``ANGLES`` orientations *pil_img* is rotated by.

    Args:
        pil_img: Input image; it is converted to RGB before preprocessing.

    Returns:
        A dict with keys:
        - ``"detected"``: the entry of ``ANGLES`` with the highest probability,
        - ``"correction"``: ``(360 - detected) % 360``, the angle to pass to
          ``correct_image`` to undo the detected rotation,
        - ``"probs"``: ``{"<angle>°": "<probability with 4 decimals>"}``.
    """
    tensor = preprocess(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)

    # Mixed precision only on CUDA. torch.autocast with the real device type
    # replaces the deprecated torch.cuda.amp.autocast, which also emits a
    # "CUDA is not available" warning every call when running on CPU.
    use_amp = DEVICE.type == "cuda"
    with torch.no_grad(), torch.autocast(device_type=DEVICE.type, enabled=use_amp):
        logits = model(pixel_values=tensor).logits

    # Softmax in float32 so half-precision logits don't lose resolution.
    probs = torch.softmax(logits.float(), dim=1)[0].cpu()
    pred = probs.argmax().item()

    # NOTE(review): assumes logit index i corresponds to ANGLES[i]; confirm
    # against the model's config.id2label.
    detected = ANGLES[pred]
    correction = (360 - detected) % 360
    return {
        "detected": detected,
        "correction": correction,
        "probs": {f"{a}°": f"{probs[i].item():.4f}" for i, a in enumerate(ANGLES)},
    }
|
|
def correct_image(pil_img: Image.Image, correction: int) -> Image.Image:
    """Return *pil_img* rotated counter-clockwise by *correction* degrees.

    ``correction`` is reduced modulo 360 first, so equivalent angles such as
    -90 and 270 (or 450 and 90) give the same result. Right angles are
    applied with a lossless ``transpose``. Any angle that is not a multiple
    of 90 (including 0) yields an unrotated copy.

    Returns:
        A new image; *pil_img* itself is never modified.
    """
    transposes = {
        90: Image.ROTATE_90,
        180: Image.ROTATE_180,
        270: Image.ROTATE_270,
    }
    method = transposes.get(correction % 360)
    if method is None:
        return pil_img.copy()
    return pil_img.transpose(method)
|
|
def load_url(url: str) -> Image.Image:
    """Download *url* and decode the response body as an RGB PIL image.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status
            (instead of trying to decode an error page as an image).
        requests.Timeout: if connecting or reading exceeds 15 seconds.
        PIL.UnidentifiedImageError: if the body is not a decodable image.
    """
    response = requests.get(url, timeout=15)
    response.raise_for_status()
    # convert() loads the pixel data into a new image, so the source image
    # can be closed right away.
    # NOTE(review): EXIF orientation tags are not applied here; the model
    # therefore sees the raw pixel orientation — confirm that is intended.
    with Image.open(BytesIO(response.content)) as img:
        return img.convert("RGB")
|
|
| |
| |
| |
def fix_image_from_url(url: str):
    """Download an image, detect its rotation, undo it and plot before/after.

    Prints the detected angle, the applied correction and the per-angle
    probabilities from ``predict_rotation``, then shows the input and the
    corrected image side by side.

    Args:
        url: URL of the image to download.

    Returns:
        The corrected PIL image.
    """
    img = load_url(url)
    result = predict_rotation(img)
    corrected = correct_image(img, result["correction"])

    print(f"Recognized: {result['detected']}° | Correction: {result['correction']}°")
    print(f"Probs: {result['probs']}")

    _show_side_by_side(img, corrected)
    return corrected


def _show_side_by_side(before, after) -> None:
    """Plot *before* and *after* next to each other, titled Input / Corrected."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, image, title in zip(axes, (before, after), ("Input", "Corrected")):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    plt.show()
|
|
# Demo: run the whole pipeline on a sample photo and keep the corrected PIL
# image in `corrected` for further inspection.
corrected = fix_image_from_url(
    "https://lh-tech.de/pexels-ana-ibarra-2152867215-32441547.jpg"
)