# GyroScope / use.py
# Author: LH-Tech-AI
# Last commit: "Update use.py" (2259c4b, verified)
# ================================================================
# πŸ” INFERENCE β€” Load image from URL and correct the rotation
# ================================================================
import os
from io import BytesIO

import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from torchvision import transforms
from transformers import ResNetForImageClassification
# Directory holding the fine-tuned rotation classifier. Defaults to the Kaggle
# output path used when the model was trained there; set the environment
# variable ROTATION_MODEL_DIR to load it from anywhere else.
MODEL_DIR = os.environ.get("ROTATION_MODEL_DIR", "/kaggle/working/rotation_model")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Class index -> rotation angle in degrees (predict_rotation indexes this with
# the argmax class). Presumably matches the label order used in training —
# confirm against the training script if the model is retrained.
ANGLES = [0, 90, 180, 270]

# ── Load model ──
# .eval() switches the network to inference mode (no dropout, frozen BN stats).
model = ResNetForImageClassification.from_pretrained(MODEL_DIR).to(DEVICE).eval()

# ImageNet-style preprocessing: resize the shorter side to 256, centre-crop to
# 224x224, convert to a [0, 1] float tensor and normalise with the ImageNet
# channel mean/std. Assumed to mirror the training-time transform — verify.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def predict_rotation(pil_img: Image.Image) -> dict:
    """Predict how far *pil_img* is rotated and the angle that undoes it.

    Args:
        pil_img: Input image; it is converted to RGB before preprocessing.

    Returns:
        A dict with:
        - ``"detected"``: predicted rotation in degrees (an element of ``ANGLES``).
        - ``"correction"``: ``(360 - detected) % 360``, the counter-rotation.
        - ``"probs"``: per-angle softmax probability as a 4-decimal string,
          keyed like ``"90°"``.
    """
    # Preprocess and add a batch dimension of 1.
    tensor = preprocess(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
    # Mixed precision only on CUDA. torch.cuda.amp.autocast() is deprecated
    # and, on CPU-only machines, emits a warning before disabling itself.
    use_amp = DEVICE.type == "cuda"
    with torch.no_grad(), torch.autocast(device_type=DEVICE.type, enabled=use_amp):
        logits = model(pixel_values=tensor).logits
    # Under autocast the logits may be float16; normalise in float32.
    probs = torch.softmax(logits.float(), dim=1)[0].cpu()
    pred = probs.argmax().item()
    detected = ANGLES[pred]
    correction = (360 - detected) % 360
    return {
        "detected": detected,
        "correction": correction,
        "probs": {f"{angle}°": f"{probs[i]:.4f}" for i, angle in enumerate(ANGLES)},
    }
def correct_image(pil_img: Image.Image, correction: int) -> Image.Image:
    """Undo a detected rotation by transposing *pil_img* by *correction* degrees.

    Only 90, 180 and 270 trigger a transpose (PIL's ``ROTATE_*`` transposes
    turn the image counter-clockwise). Any other value, including 0, yields
    an unmodified copy, so the caller always receives a new image object.
    """
    for angle, method in (
        (90, Image.ROTATE_90),
        (180, Image.ROTATE_180),
        (270, Image.ROTATE_270),
    ):
        if correction == angle:
            return pil_img.transpose(method)
    return pil_img.copy()
def load_url(url: str) -> Image.Image:
    """Download the image at *url* and return it as an RGB PIL image.

    Raises:
        requests.HTTPError: if the server responds with a 4xx/5xx status.
        Other ``requests`` exceptions (connection errors, the 15 s connect/read
        timeout) and Pillow decoding errors propagate unchanged.
    """
    response = requests.get(url, timeout=15)
    # Fail loudly on HTTP errors; otherwise an HTML error page would be handed
    # to Image.open and surface as a misleading "cannot identify image file".
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
# ═══════════════════════════════════════════
# Directly: Rotated Image from URL
# ═══════════════════════════════════════════
def fix_image_from_url(url: str) -> Image.Image:
    """Download an image, correct its rotation, and show before/after.

    Prints the detected rotation, the applied correction and the per-angle
    probabilities, then plots the input next to the corrected image.

    Args:
        url: HTTP(S) address of the image to fix.

    Returns:
        The rotation-corrected PIL image.
    """
    img = load_url(url)
    result = predict_rotation(img)
    corrected = correct_image(img, result["correction"])
    print(f"📐 Recognized: {result['detected']}° | Correction: {result['correction']}°")
    print(f"📊 Probs: {result['probs']}")
    _, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, shown, title in zip(axes, (img, corrected), ("Input", "Corrected")):
        ax.imshow(shown)
        ax.set_title(title)
        ax.axis("off")
    plt.tight_layout()
    plt.show()
    return corrected
# Demo: fix a sample rotated photo. This runs as a module-level statement,
# i.e. whenever the script is executed or imported.
corrected = fix_image_from_url(
    url="https://lh-tech.de/pexels-ana-ibarra-2152867215-32441547.jpg",
)