| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import subprocess |
| import sys |
|
|
| from transformers import BertConfig, BertModel, BertTokenizer, pipeline |
| from transformers.testing_utils import TestCasePlus, require_torch |
|
|
|
|
class OfflineTests(TestCasePlus):
    """Loading cached models must keep working when the network is disabled or unreachable.

    Each scenario runs in a fresh ``python -c`` subprocess assembled from small source
    snippets. This way ``TRANSFORMERS_OFFLINE`` can be supplied through the subprocess
    environment, and ``socket.socket`` can be monkey-patched without affecting the
    pytest process. (Presumably the flag is only honoured when ``transformers`` is
    first imported, which has already happened inside pytest.)
    """

    _TINY_BERT = "hf-internal-testing/tiny-random-bert"

    # Import-only snippet shared by the first two tests. It is always placed ahead of
    # the socket mock so that import-time code still sees the real ``socket.socket``
    # and only the model loading in the ``run`` snippet is exposed to the mock.
    _BERT_LOAD = """
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
"""

    # Loads config, model, tokenizer and a fill-mask pipeline, then prints "success"
    # so the parent process can confirm the snippet ran to completion.
    _BERT_RUN = """
mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
pipe = pipeline(task="fill-mask", model=mname)
print("success")
"""

    def _execute_with_env(self, *snippets, **env_overwrite):
        """Run ``snippets`` (joined by newlines, in the given order) in a fresh interpreter.

        Args:
            *snippets: Python source fragments passed to ``python -c``.
            **env_overwrite: Environment variables set on top of ``self.get_env()``.

        Returns:
            The ``subprocess.CompletedProcess``, with stdout/stderr captured as bytes.
        """
        cmd = [sys.executable, "-c", "\n".join(snippets)]
        env = self.get_env()
        env.update(env_overwrite)
        return subprocess.run(cmd, env=env, check=False, capture_output=True)

    def _assert_success(self, result):
        """Assert that the subprocess exited with status 0 and printed ``success``."""
        self.assertEqual(result.returncode, 0, result.stderr)
        self.assertIn("success", result.stdout.decode())

    @classmethod
    def _warm_cache(cls):
        """Load the tiny BERT checkpoint in-process so the subprocess can find its files in the local cache."""
        mname = cls._TINY_BERT
        BertConfig.from_pretrained(mname)
        BertModel.from_pretrained(mname)
        BertTokenizer.from_pretrained(mname)
        pipeline(task="fill-mask", model=mname)

    @require_torch
    def test_offline_mode(self):
        mock = """
import socket
def offline_socket(*args, **kwargs): raise RuntimeError("Offline mode is enabled, we shouldn't access internet")
socket.socket = offline_socket
"""
        self._warm_cache()

        # The mock is placed *before* the loading snippet, so any attempt to open a
        # socket while loading raises. With TRANSFORMERS_OFFLINE=1 everything must be
        # served from the cache populated above.
        result = self._execute_with_env(self._BERT_LOAD, mock, self._BERT_RUN, TRANSFORMERS_OFFLINE="1")
        self._assert_success(result)

    @require_torch
    def test_offline_mode_no_internet(self):
        mock = """
import socket
def offline_socket(*args, **kwargs): raise socket.error("Faking flaky internet")
socket.socket = offline_socket
"""
        self._warm_cache()

        # TRANSFORMERS_OFFLINE is left unset, but every socket fails with socket.error
        # while loading (simulated flaky network). The test expects loading to still
        # succeed using the files cached above.
        result = self._execute_with_env(self._BERT_LOAD, mock, self._BERT_RUN)
        self._assert_success(result)

    @require_torch
    def test_offline_mode_sharded_checkpoint(self):
        load = """
from transformers import BertConfig, BertModel, BertTokenizer
"""
        run = """
mname = "hf-internal-testing/tiny-random-bert-sharded"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
print("success")
"""
        mock = """
import socket
def offline_socket(*args, **kwargs): raise ValueError("Offline mode is enabled")
socket.socket = offline_socket
"""
        # Baseline: load the sharded checkpoint with normal network access. This also
        # leaves its files in the cache for the offline run below.
        self._assert_success(self._execute_with_env(load, run))

        # With every socket disabled before loading, this should succeed because
        # TRANSFORMERS_OFFLINE=1 tells it to use the locally cached files.
        self._assert_success(self._execute_with_env(load, mock, run, TRANSFORMERS_OFFLINE="1"))

    @require_torch
    def test_offline_mode_pipeline_exception(self):
        load = """
from transformers import pipeline
"""
        run = """
mname = "hf-internal-testing/tiny-random-bert"
pipe = pipeline(model=mname)
"""
        mock = """
import socket
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
socket.socket = offline_socket
"""
        # Without an explicit task, `pipeline` cannot infer one in offline mode; the
        # subprocess is expected to exit with status 1 and this message on stderr.
        result = self._execute_with_env(load, mock, run, TRANSFORMERS_OFFLINE="1")
        self.assertEqual(result.returncode, 1, result.stderr)
        self.assertIn(
            "You cannot infer task automatically within `pipeline` when using offline mode", result.stderr.decode()
        )
|
|