Update README.md
Browse files
README.md
CHANGED
|
@@ -54,6 +54,22 @@ class OptILMClassifier(nn.Module):
|
|
| 54 |
logits = self.classifier(combined_input)
|
| 55 |
return logits
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
def preprocess_input(tokenizer, system_prompt, initial_query):
|
| 58 |
combined_input = f"{system_prompt}\n\nUser: {initial_query}"
|
| 59 |
encoding = tokenizer.encode_plus(
|
|
|
|
| 54 |
logits = self.classifier(combined_input)
|
| 55 |
return logits
|
| 56 |
|
| 57 |
+
|
| 58 |
+
def load_optillm_model():
    """Build the OptILM classifier and restore its trained weights.

    Constructs the BERT-large backbone, wraps it in ``OptILMClassifier``
    with one output per entry in ``APPROACHES``, moves it to the best
    available device (MPS, then CUDA, then CPU), and loads the trained
    parameters from the ``model.safetensors`` checkpoint published under
    ``MODEL_NAME``.

    Returns:
        tuple: ``(model, tokenizer, device)`` — the weight-loaded
        classifier, its matching tokenizer, and the selected device.
    """
    # Prefer Apple-silicon MPS, then CUDA, falling back to CPU.
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Backbone encoder that the classification head sits on top of.
    backbone = AutoModel.from_pretrained("google-bert/bert-large-uncased")
    classifier = OptILMClassifier(backbone, num_labels=len(APPROACHES))
    classifier.to(device)

    # Pull the checkpoint from the Hub and load it into the wrapper model.
    checkpoint_path = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
    load_model(classifier, checkpoint_path)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return classifier, tokenizer, device
|
| 72 |
+
|
| 73 |
def preprocess_input(tokenizer, system_prompt, initial_query):
|
| 74 |
combined_input = f"{system_prompt}\n\nUser: {initial_query}"
|
| 75 |
encoding = tokenizer.encode_plus(
|