| import requests |
| import torch |
| from PIL import Image |
| import hashlib |
| import tempfile |
| import unittest |
| from io import BytesIO |
| from pathlib import Path |
| from unittest.mock import patch |
|
|
| from urllib3 import HTTPResponse |
| from urllib3._collections import HTTPHeaderDict |
|
|
| import open_clip |
| from open_clip.pretrained import download_pretrained_from_url |
|
|
|
|
class DownloadPretrainedTests(unittest.TestCase):
    """Tests for ``open_clip.pretrained.download_pretrained_from_url``.

    The download-path tests never touch the network: the ``urllib`` module used
    inside ``open_clip.pretrained`` is patched and fed synthetic urllib3
    responses built by :meth:`create_response`. The URLs encode the expected
    SHA256 (full digest for openaipublic, first 8 hex chars for mlfoundations),
    which is what ``download_pretrained_from_url`` validates against.

    ``test_download_pretrained_from_hfh`` is an integration test that requires
    network access (Hugging Face Hub + a sample image).
    """

    def create_response(self, data, status_code=200, content_type='application/octet-stream'):
        """Return a urllib3 ``HTTPResponse`` that serves ``data``.

        ``preload_content=False`` keeps the body unread so the code under test
        can stream it, mimicking a real ``urlopen`` result.
        """
        fp = BytesIO(data)
        headers = HTTPHeaderDict({
            'Content-Type': content_type,
            'Content-Length': str(len(data))
        })
        raw = HTTPResponse(fp, preload_content=False, headers=headers, status=status_code)
        return raw

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic(self, urllib):
        """Fresh download succeeds when the file matches the SHA256 in the URL."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_corrupted(self, urllib):
        """A download whose checksum mismatches the URL hash raises RuntimeError."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
        with tempfile.TemporaryDirectory() as root:
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            # NOTE(review): the doubled "not" is deliberate here — it appears to
            # mirror the library's actual error message text; keep verbatim.
            with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
                download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_valid_cache(self, urllib):
        """A cached file with a valid checksum is reused; no download happens."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            local_file = Path(root) / 'RN50.pt'
            local_file.write_bytes(file_contents)
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_not_called()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_corrupted_cache(self, urllib):
        """A cached file with a bad checksum is discarded and re-downloaded."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            local_file = Path(root) / 'RN50.pt'
            local_file.write_bytes(b'corrupted pretrained model')
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_mlfoundations(self, urllib):
        """mlfoundations URLs embed only the first 8 hex chars of the SHA256."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_mlfoundations_corrupted(self, urllib):
        """Checksum mismatch raises even with the short-hash (mlfoundations) URL form."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
        urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
        with tempfile.TemporaryDirectory() as root:
            url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
            # NOTE(review): doubled "not" intentionally matches the library's
            # error message text; keep verbatim.
            with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
                download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    def test_download_pretrained_from_hfh(self):
        """Integration test: load a tiny model from the HF Hub and check its
        zero-shot probabilities on a sample image.

        Requires network access (Hugging Face Hub and the sample image URL).
        Fix: this test previously patched ``open_clip.pretrained.urllib`` and
        took an unused mock parameter, but the model/image are fetched via
        huggingface_hub/requests, so the patch was dead code and misleading —
        removed.
        """
        model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:hf-internal-testing/tiny-open-clip-model')
        tokenizer = open_clip.get_tokenizer('hf-hub:hf-internal-testing/tiny-open-clip-model')
        img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
        image = preprocess(Image.open(requests.get(img_url, stream=True).raw)).unsqueeze(0)
        text = tokenizer(["a diagram", "a dog", "a cat"])

        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            # Normalize to unit length so the dot product below is a cosine similarity.
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)

        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        # Third positional arg of torch.allclose is rtol.
        self.assertTrue(torch.allclose(text_probs, torch.tensor([[0.0597, 0.6349, 0.3053]]), 1e-3))
|
|