LH-Tech-AI commited on
Commit
164d0d2
Β·
verified Β·
1 Parent(s): 4f4e7db

Create use.py

Browse files
Files changed (1) hide show
  1. use.py +64 -0
use.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ================================================================
2
+ # πŸ” INFERENCE β€” Load image from URL and correct the rotation
3
+ # ================================================================
4
+
5
+ import requests, torch
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ from torchvision import transforms
9
+ from transformers import ResNetForImageClassification
10
+ import matplotlib.pyplot as plt
11
+
12
+ MODEL_DIR = "/kaggle/working/rotation_model"
13
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ ANGLES = [0, 90, 180, 270]
15
+
16
+ # ── Load model ──
17
+ model = ResNetForImageClassification.from_pretrained(MODEL_DIR).to(DEVICE).eval()
18
+
19
+ preprocess = transforms.Compose([
20
+ transforms.Resize(256),
21
+ transforms.CenterCrop(224),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
24
+ ])
25
+
26
+ def predict_rotation(pil_img: Image.Image) -> dict:
27
+ tensor = preprocess(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
28
+ with torch.no_grad(), torch.cuda.amp.autocast():
29
+ logits = model(pixel_values=tensor).logits
30
+ probs = torch.softmax(logits, dim=1)[0].cpu()
31
+ pred = probs.argmax().item()
32
+
33
+ detected = ANGLES[pred]
34
+ correction = (360 - detected) % 360
35
+ return {"detected": detected, "correction": correction,
36
+ "probs": {f"{a}Β°": f"{probs[i]:.4f}" for i, a in enumerate(ANGLES)}}
37
+
38
+ def correct_image(pil_img: Image.Image, correction: int) -> Image.Image:
39
+ if correction == 90: return pil_img.transpose(Image.ROTATE_90)
40
+ elif correction == 180: return pil_img.transpose(Image.ROTATE_180)
41
+ elif correction == 270: return pil_img.transpose(Image.ROTATE_270)
42
+ return pil_img.copy()
43
+
44
+ def load_url(url: str) -> Image.Image:
45
+ return Image.open(BytesIO(requests.get(url, timeout=15).content)).convert("RGB")
46
+
47
+ # ═══════════════════════════════════════════
48
+ # Directly: Rotated Image from URL
49
+ # ═══════════════════════════════════════════
50
+ def fix_image_from_url(url: str):
51
+ img = load_url(url)
52
+ result = predict_rotation(img)
53
+ corrected = correct_image(img, result["correction"])
54
+
55
+ print(f"πŸ“ Recognized: {result['detected']}Β° | Correction: {result['correction']}Β°")
56
+ print(f"πŸ“Š Probs: {result['probs']}")
57
+
58
+ fig, axes = plt.subplots(1, 2, figsize=(12, 5))
59
+ axes[0].imshow(img); axes[0].set_title("Input"); axes[0].axis("off")
60
+ axes[1].imshow(corrected); axes[1].set_title("Korrigiert"); axes[1].axis("off")
61
+ plt.tight_layout(); plt.show()
62
+ return corrected
63
+
64
+ corrected = fix_image_from_url("https://lh-tech.de/pexels-ana-ibarra-2152867215-32441547.jpg")