File size: 783 Bytes
f50d9fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
from io import BytesIO

import requests
import torch
from PIL import Image
from safetensors import safe_open
def download_sample_image(*, timeout: float = 30.0) -> Image.Image:
    """Download a sample chest X-ray image (CC-licensed) from Wikimedia Commons.

    Args:
        timeout: Seconds ``requests`` waits to establish the connection and
            for each read from the server. Without a timeout the request can
            block indefinitely.

    Returns:
        The image, opened from an in-memory copy of the downloaded bytes so
        it remains usable after the HTTP connection has been closed.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    base_url = "https://upload.wikimedia.org/wikipedia/commons"
    path = "2/20/Chest_X-ray_in_influenza_and_Haemophilus_influenzae.jpg"
    image_url = f"{base_url}/{path}"
    # Wikimedia's User-Agent policy asks clients to identify themselves.
    headers = {"User-Agent": "RAD-DINO"}
    with requests.get(image_url, headers=headers, timeout=timeout) as response:
        # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
        response.raise_for_status()
        # ``content`` is decoded according to Content-Encoding (e.g. gzip),
        # whereas ``raw`` is the undecoded socket stream; reading it fully
        # also lets the connection be released when the ``with`` exits.
        image_bytes = response.content
    return Image.open(BytesIO(image_bytes))
def safetensors_to_state_dict(checkpoint_path: str) -> dict[str, torch.Tensor]:
    """Load every tensor stored in a ``.safetensors`` file into a plain dict.

    Args:
        checkpoint_path: Path to the safetensors checkpoint file.

    Returns:
        Mapping from each tensor name in the file to its tensor, loaded
        through safetensors' PyTorch backend (``framework="pt"``).
    """
    with safe_open(checkpoint_path, framework="pt") as ckpt_file:
        # The comprehension is fully evaluated before the file handle closes.
        return {key: ckpt_file.get_tensor(key) for key in ckpt_file.keys()}
|