Updates inference notebook.
Browse files
model/notebooks/inference.ipynb
CHANGED
|
@@ -16,7 +16,7 @@
|
|
| 16 |
"import torch\n",
|
| 17 |
"\n",
|
| 18 |
"from huggingface_hub import hf_hub_download\n",
|
| 19 |
-
"from transformers import AutoTokenizer\n",
|
| 20 |
"\n",
|
| 21 |
"from model.distilbert import DistilBertClassificationModel\n",
|
| 22 |
"from model.scibert import SciBertClassificationModel\n",
|
|
@@ -39,7 +39,7 @@
|
|
| 39 |
"outputs": [],
|
| 40 |
"source": [
|
| 41 |
"# Baseline\n",
|
| 42 |
-
"repo_id = \"ppak10/defect-classification-distilbert-baseline-25-epochs\"\n",
|
| 43 |
"# repo_id = \"ppak10/defect-classification-scibert-baseline-25-epochs\"\n",
|
| 44 |
"# repo_id = \"ppak10/defect-classification-llama-baseline-25-epochs\"\n",
|
| 45 |
"# repo_id = \"ppak10/defect-classification-t5-baseline-25-epochs\"\n",
|
|
@@ -48,7 +48,7 @@
|
|
| 48 |
"# repo_id = \"ppak10/defect-classification-distilbert-prompt-02-epochs\"\n",
|
| 49 |
"# repo_id = \"ppak10/defect-classification-scibert-prompt-02-epochs\"\n",
|
| 50 |
"# repo_id = \"ppak10/defect-classification-llama-prompt-02-epochs\"\n",
|
| 51 |
-
"# repo_id = \"ppak10/defect-classification-t5-prompt-02-epochs\"\n",
|
| 52 |
"\n",
|
| 53 |
"# Initialize the model\n",
|
| 54 |
"model = DistilBertClassificationModel(repo_id)\n",
|
|
@@ -63,8 +63,11 @@
|
|
| 63 |
"metadata": {},
|
| 64 |
"outputs": [],
|
| 65 |
"source": [
|
| 66 |
-
"# Loads tokenizer\n",
|
| 67 |
-
"tokenizer = AutoTokenizer.from_pretrained(repo_id)\n",
|
|
|
|
|
|
|
|
|
|
| 68 |
"\n",
|
| 69 |
"# Loads classification head weights\n",
|
| 70 |
"classification_head_path = hf_hub_download(\n",
|
|
@@ -84,10 +87,10 @@
|
|
| 84 |
"outputs": [],
|
| 85 |
"source": [
|
| 86 |
"# Baseline\n",
|
| 87 |
-
"text = \"Ti-6Al-4V[SEP]280.0 W[SEP]400.0 mm/s[SEP]100.0 microns[SEP]50.0 microns[SEP]100.0 microns\"\n",
|
| 88 |
"\n",
|
| 89 |
"# Prompt\n",
|
| 90 |
-
"# text = \"What are the likely imperfections that occur in Ti-6Al-4V L-PBF builds at 280.0 W, given a 100.0 microns beam diameter, a 400.0 mm/s scan speed, a 100.0 microns hatch spacing, and a 50.0 microns layer height?\""
|
| 91 |
]
|
| 92 |
},
|
| 93 |
{
|
|
@@ -99,7 +102,7 @@
|
|
| 99 |
"# Tokenize inputs \n",
|
| 100 |
"inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=\"max_length\", max_length=256)\n",
|
| 101 |
"\n",
|
| 102 |
-
"# For SciBERT specific case.\n",
|
| 103 |
"inputs_kwargs = {}\n",
|
| 104 |
"for key, value in inputs.items():\n",
|
| 105 |
" if key not in [\"token_type_ids\"]:\n",
|
|
|
|
| 16 |
"import torch\n",
|
| 17 |
"\n",
|
| 18 |
"from huggingface_hub import hf_hub_download\n",
|
| 19 |
+
"from transformers import AutoTokenizer, T5Tokenizer\n",
|
| 20 |
"\n",
|
| 21 |
"from model.distilbert import DistilBertClassificationModel\n",
|
| 22 |
"from model.scibert import SciBertClassificationModel\n",
|
|
|
|
| 39 |
"outputs": [],
|
| 40 |
"source": [
|
| 41 |
"# Baseline\n",
|
| 42 |
+
"# repo_id = \"ppak10/defect-classification-distilbert-baseline-25-epochs\"\n",
|
| 43 |
"# repo_id = \"ppak10/defect-classification-scibert-baseline-25-epochs\"\n",
|
| 44 |
"# repo_id = \"ppak10/defect-classification-llama-baseline-25-epochs\"\n",
|
| 45 |
"# repo_id = \"ppak10/defect-classification-t5-baseline-25-epochs\"\n",
|
|
|
|
| 48 |
"# repo_id = \"ppak10/defect-classification-distilbert-prompt-02-epochs\"\n",
|
| 49 |
"# repo_id = \"ppak10/defect-classification-scibert-prompt-02-epochs\"\n",
|
| 50 |
"# repo_id = \"ppak10/defect-classification-llama-prompt-02-epochs\"\n",
|
| 51 |
+
"repo_id = \"ppak10/defect-classification-t5-prompt-02-epochs\"\n",
|
| 52 |
"\n",
|
| 53 |
"# Initialize the model\n",
|
| 54 |
"model = DistilBertClassificationModel(repo_id)\n",
|
|
|
|
| 63 |
"metadata": {},
|
| 64 |
"outputs": [],
|
| 65 |
"source": [
|
| 66 |
+
"# Uncomment for DistilBERT, SciBERT, and Llama\n",
|
| 67 |
+
"# tokenizer = AutoTokenizer.from_pretrained(repo_id)\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"# Uncomment for T5\n",
|
| 70 |
+
"tokenizer = T5Tokenizer.from_pretrained(repo_id)\n",
|
| 71 |
"\n",
|
| 72 |
"# Loads classification head weights\n",
|
| 73 |
"classification_head_path = hf_hub_download(\n",
|
|
|
|
| 87 |
"outputs": [],
|
| 88 |
"source": [
|
| 89 |
"# Baseline\n",
|
| 90 |
+
"# text = \"Ti-6Al-4V[SEP]280.0 W[SEP]400.0 mm/s[SEP]100.0 microns[SEP]50.0 microns[SEP]100.0 microns\"\n",
|
| 91 |
"\n",
|
| 92 |
"# Prompt\n",
|
| 93 |
+
"text = \"What are the likely imperfections that occur in Ti-6Al-4V L-PBF builds at 280.0 W, given a 100.0 microns beam diameter, a 400.0 mm/s scan speed, a 100.0 microns hatch spacing, and a 50.0 microns layer height?\""
|
| 94 |
]
|
| 95 |
},
|
| 96 |
{
|
|
|
|
| 102 |
"# Tokenize inputs \n",
|
| 103 |
"inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=\"max_length\", max_length=256)\n",
|
| 104 |
"\n",
|
| 105 |
+
"# For SciBERT specific case. \n",
|
| 106 |
"inputs_kwargs = {}\n",
|
| 107 |
"for key, value in inputs.items():\n",
|
| 108 |
" if key not in [\"token_type_ids\"]:\n",
|