Reyad-Ahmmed committed on
Commit
bde1ce5
·
verified ·
1 Parent(s): eb45de7

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +28 -7
handler.py CHANGED
@@ -1,22 +1,29 @@
1
  import json
2
  from datetime import datetime
 
 
3
 
4
  class EndpointHandler:
5
  def __init__(self, model_dir):
6
  self.model_dir = model_dir
7
 
8
  def load(self):
9
- print("Model loading skipped for Hello World API.")
 
 
 
 
 
 
10
 
11
  def __call__(self, inputs):
12
  """
13
- Hugging Face expects the `EndpointHandler` to be callable.
14
- So, we define `__call__` instead of `predict`.
15
  """
16
  try:
17
- # Hugging Face sends input as a LIST or DICT, so handle both cases
18
  if isinstance(inputs, list) and len(inputs) > 0:
19
- user_text = inputs[0] # Extract text from the list
20
  elif isinstance(inputs, dict) and "inputs" in inputs:
21
  user_text = inputs["inputs"]
22
  else:
@@ -25,8 +32,22 @@ class EndpointHandler:
25
  # Generate timestamp
26
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
27
 
28
- # Return formatted message
29
- return {"message": f"Received at {current_time}: {user_text}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  except Exception as e:
32
  return {"error": f"Unexpected error: {str(e)}"}
 
1
  import json
2
  from datetime import datetime
3
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
+ import torch
5
 
6
  class EndpointHandler:
7
  def __init__(self, model_dir):
8
  self.model_dir = model_dir
9
 
10
  def load(self):
11
+ """
12
+ Load a simple DistilBERT model for text classification.
13
+ """
14
+ model_name = "distilbert-base-uncased-finetuned-sst-2-english" # Pretrained model for sentiment analysis
15
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
16
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
17
+ print(f"Loaded model: {model_name}")
18
 
19
  def __call__(self, inputs):
20
  """
21
+ Process user input and classify the text using DistilBERT.
 
22
  """
23
  try:
24
+ # Handle different input formats
25
  if isinstance(inputs, list) and len(inputs) > 0:
26
+ user_text = inputs[0]
27
  elif isinstance(inputs, dict) and "inputs" in inputs:
28
  user_text = inputs["inputs"]
29
  else:
 
32
  # Generate timestamp
33
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
34
 
35
+ # Tokenize input text
36
+ inputs = self.tokenizer(user_text, return_tensors="pt")
37
+
38
+ # Perform inference
39
+ with torch.no_grad():
40
+ outputs = self.model(**inputs)
41
+
42
+ # Get predicted label (0 = negative, 1 = positive)
43
+ predicted_label = torch.argmax(outputs.logits).item()
44
+ label_map = {0: "negative", 1: "positive"}
45
+
46
+ return {
47
+ "timestamp": current_time,
48
+ "input_text": user_text,
49
+ "predicted_label": label_map[predicted_label]
50
+ }
51
 
52
  except Exception as e:
53
  return {"error": f"Unexpected error: {str(e)}"}