munzirmuneer committed on
Commit
82797d0
·
verified ·
1 Parent(s): 98c6ec2

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +46 -0
handler.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from peft import PeftModel
5
+
6
+
7
class EndpointHandler:
    """Callable inference handler that classifies a single URL.

    An instance holds a tokenizer and a PEFT-wrapped sequence-classification
    model. Calling the instance with ``{"inputs": "<url>"}`` returns the
    arg-max class index together with the softmax probabilities.
    """

    def __init__(self, model_dir):
        """Load the tokenizer and the adapter-wrapped classification model.

        Args:
            model_dir: Accepted so the constructor signature matches what the
                caller passes in. It is currently unused: the checkpoints are
                selected by the hard-coded identifiers below.
        """
        adapter_id = "munzirmuneer/phishing_url_gemma_pytorch"  # Replace with your specific model
        tokenizer_id = "google/gemma-2b"

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        # NOTE(review): the same identifier is used both for the base model and
        # for the adapter weights -- confirm this is intentional.
        base_model = AutoModelForSequenceClassification.from_pretrained(adapter_id)
        self.model = PeftModel.from_pretrained(base_model, adapter_id)
        # Be explicit about inference mode so dropout-style layers are disabled
        # regardless of how the loaders left the module.
        self.model.eval()

    def __call__(self, input_data):
        """Classify the URL found under ``input_data["inputs"]``.

        Args:
            input_data: Mapping with an ``"inputs"`` entry holding one URL as a
                string. A one-element list/tuple wrapping that string is also
                accepted and unwrapped.

        Returns:
            dict: ``{"predicted_class": int, "probabilities": list[float]}``,
            where the probabilities are the softmax over the model's logits.

        Raises:
            ValueError: If ``input_data`` has no ``"inputs"`` entry, or if the
                entry is not exactly one non-empty URL string.
        """
        # Extract the URL from the input_data dictionary
        try:
            url = input_data["inputs"]
        except (KeyError, TypeError):
            raise ValueError(
                "Input data must contain the 'inputs' key with a URL."
            ) from None

        # A single URL wrapped in a list tokenizes to the same 1-row batch,
        # so keep accepting it; larger batches cannot be reduced to one
        # prediction and are rejected below with a clear message.
        if isinstance(url, (list, tuple)) and len(url) == 1:
            url = url[0]
        if not isinstance(url, str) or not url.strip():
            raise ValueError(
                "The 'inputs' value must be a single non-empty URL string."
            )

        # Tokenize input
        encoded = self.tokenizer(
            url, return_tensors="pt", truncation=True, padding=True
        )

        # Run inference without tracking gradients
        with torch.no_grad():
            outputs = self.model(**encoded)

        # Convert logits to probabilities; row 0 is the only row because
        # exactly one URL was tokenized.
        probs = F.softmax(outputs.logits, dim=-1)[0]

        # Get the predicted class (highest probability)
        pred_class = torch.argmax(probs, dim=-1)

        return {
            "predicted_class": pred_class.item(),
            "probabilities": probs.tolist(),
        }