| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| | import torch |
| | import torch.nn.functional as F |
| | from peft import PeftModel |
| |
|
| |
|
class EndpointHandler:
    """Callable handler that classifies URL text with a PEFT-wrapped sequence classifier."""

    def __init__(self, model_dir):
        """
        Load the tokenizer and the adapter-wrapped classification model.

        Args:
            model_dir: Directory passed in by the caller. It is accepted for
                interface compatibility but is not read here; the tokenizer
                and weights are loaded from the hard-coded repository ids
                below.
        """
        adapter_repo = "munzirmuneer/phishing_url_gemma_pytorch"
        tokenizer_repo = "google/gemma-2b"

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)
        # NOTE(review): the same repository id is used both for the base model
        # and for the adapter — confirm this is the intended loading sequence.
        base_model = AutoModelForSequenceClassification.from_pretrained(adapter_repo)
        self.model = PeftModel.from_pretrained(base_model, adapter_repo)
        # Make inference mode explicit so layers such as dropout are disabled.
        self.model.eval()

    def __call__(self, input_data):
        """
        Run inference on the text found under the 'inputs' key.

        Args:
            input_data: Mapping that must contain an 'inputs' entry holding a
                single string or a list of strings.

        Returns:
            For a single sequence, a dict with "predicted_class" (int) and
            "probabilities" (list of floats), exactly as before. When several
            sequences are supplied, a list of such dicts in input order.

        Raises:
            ValueError: If 'inputs' is missing or is an empty list/tuple.
        """
        if 'inputs' not in input_data:
            raise ValueError("Input data must contain the 'inputs' key with a URL.")
        input_text = input_data['inputs']

        if isinstance(input_text, (list, tuple)) and not input_text:
            raise ValueError("The 'inputs' value must not be an empty list.")

        inputs = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)

        with torch.no_grad():
            outputs = self.model(**inputs)

        probs = F.softmax(outputs.logits, dim=-1)
        pred_classes = torch.argmax(probs, dim=-1)

        # Build one result per row; the previous code called .item() on the
        # whole prediction tensor (which fails for more than one row) and
        # reported only the first row's probabilities.
        results = [
            {"predicted_class": label, "probabilities": row}
            for label, row in zip(pred_classes.tolist(), probs.tolist())
        ]

        # Keep the original single-dict shape whenever exactly one sequence
        # was scored, so existing callers are unaffected.
        if len(results) == 1:
            return results[0]
        return results
| |
|