File size: 2,773 Bytes
164d0d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2259c4b
164d0d2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# ================================================================
#  🔍  INFERENCE — Load image from URL and correct the rotation
# ================================================================

import requests, torch
from io import BytesIO
from PIL import Image
from torchvision import transforms
from transformers import ResNetForImageClassification
import matplotlib.pyplot as plt

# Directory the rotation classifier is loaded from.
MODEL_DIR = "/kaggle/working/rotation_model"

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Angle in degrees for each class index; indexed with the model's arg-max.
ANGLES = [0, 90, 180, 270]

# Load the classifier, move it to the chosen device and switch it to
# evaluation mode.
model = ResNetForImageClassification.from_pretrained(MODEL_DIR)
model = model.to(DEVICE)
model.eval()

# Input pipeline: resize to 256, centre-crop to 224, convert to a tensor and
# normalise each channel.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # presumably the usual ImageNet channel mean/std -- confirm they match
        # the statistics used when the model was trained
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def predict_rotation(pil_img: Image.Image) -> dict:
    """Classify an image's rotation and derive the angle that undoes it.

    The image is converted to RGB, passed through the module-level
    ``preprocess`` pipeline and ``model``, and the arg-max class index is
    looked up in ``ANGLES``.

    Args:
        pil_img: Image to classify; it is converted to RGB before use.

    Returns:
        A dict with three keys:

        * ``"detected"``: the ``ANGLES`` entry with the highest probability.
        * ``"correction"``: ``(360 - detected) % 360``.
        * ``"probs"``: maps ``"<angle>°"`` to the softmax probability of that
          class, formatted as a string with four decimals.
    """
    batch = preprocess(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)

    # Key autocast on the active device instead of the deprecated, CUDA-only
    # ``torch.cuda.amp.autocast``: mixed precision is enabled on CUDA and
    # cleanly disabled (no "CUDA is not available" warning) on CPU.
    use_amp = DEVICE.type == "cuda"
    with torch.no_grad(), torch.autocast(device_type=DEVICE.type, enabled=use_amp):
        logits = model(pixel_values=batch).logits

    # Under autocast the logits may be half precision; compute the softmax in
    # float32 so the reported probabilities are not rounded to fp16.
    probs = torch.softmax(logits.float(), dim=1)[0].cpu()
    pred = probs.argmax().item()

    detected = ANGLES[pred]
    correction = (360 - detected) % 360
    return {
        "detected": detected,
        "correction": correction,
        # The degree sign is written as a named escape so the source stays
        # ASCII-safe (the previous literal had been mis-decoded to mojibake).
        "probs": {
            f"{angle}\N{DEGREE SIGN}": f"{probs[i]:.4f}"
            for i, angle in enumerate(ANGLES)
        },
    }

def correct_image(pil_img: Image.Image, correction: int) -> Image.Image:
    """Return a new image with the given quarter-turn correction applied.

    Args:
        pil_img: Image to correct; it is not modified.
        correction: Correction angle in degrees. 90, 180 and 270 select the
            matching ``Image.ROTATE_*`` transpose; any other value yields an
            unrotated copy.

    Returns:
        The transposed image, or a copy of ``pil_img`` when no transpose
        applies.
    """
    # NOTE(review): Pillow documents ROTATE_* as counter-clockwise turns --
    # confirm this matches the direction the model's angles were defined in.
    quarter_turns = (
        (90, Image.ROTATE_90),
        (180, Image.ROTATE_180),
        (270, Image.ROTATE_270),
    )
    for angle, method in quarter_turns:
        if correction == angle:
            return pil_img.transpose(method)
    return pil_img.copy()

def load_url(url: str) -> Image.Image:
    """Download an image and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.RequestException: For other request failures, including the
            15-second timeout passed to ``requests.get``.
    """
    response = requests.get(url, timeout=15)
    # Fail on HTTP errors here; otherwise the body of an error page is handed
    # to PIL and surfaces as a misleading image-decoding error.
    response.raise_for_status()
    # ``convert`` returns an independent image, so the source image can be
    # closed as soon as the conversion is done.
    with Image.open(BytesIO(response.content)) as source:
        return source.convert("RGB")

# ═══════════════════════════════════════════
#   Directly: Rotated Image from URL
# ═══════════════════════════════════════════
def fix_image_from_url(url: str):
    """Download an image, undo its predicted rotation and show before/after.

    The image is fetched with ``load_url``, classified with
    ``predict_rotation`` and corrected with ``correct_image``. The detected
    angle, the correction and the class probabilities are printed, and the
    input and corrected images are displayed side by side with matplotlib.

    Args:
        url: Address of the image to download.

    Returns:
        The corrected image returned by ``correct_image``.
    """
    img = load_url(url)
    result = predict_rotation(img)
    corrected = correct_image(img, result["correction"])

    # The glyphs in these messages had been mis-decoded into mojibake; named
    # escapes restore them and keep the source ASCII-safe.
    # NOTE(review): the first emoji could not be recovered with certainty; a
    # triangular ruler is assumed -- confirm against the intended output.
    print(
        f"\N{TRIANGULAR RULER} Recognized: {result['detected']}\N{DEGREE SIGN}"
        f" | Correction: {result['correction']}\N{DEGREE SIGN}"
    )
    print(f"\N{BAR CHART} Probs: {result['probs']}")

    # One row, two panels: the input on the left, the corrected image on the right.
    _, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = ((img, "Input"), (corrected, "Corrected"))
    for ax, (picture, title) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis("off")
    plt.tight_layout()
    plt.show()
    return corrected

# Demo run, executed when this file runs: download the sample image, correct
# its rotation, display the comparison and keep the result in `corrected`.
corrected = fix_image_from_url("https://lh-tech.de/pexels-ana-ibarra-2152867215-32441547.jpg")