SCANSKY committed on
Commit
513c09e
·
verified ·
1 Parent(s): e3b4f0e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -38
handler.py CHANGED
@@ -2,8 +2,6 @@ from transformers import DistilBertTokenizer, DistilBertForSequenceClassificatio
2
  import torch
3
  import os
4
 
5
-
6
-
7
  # Initialize model and tokenizer
8
  model_name = "SCANSKY/distilbertTourism-multilingual-rclassifier"
9
  model = None
@@ -22,7 +20,7 @@ load_model_components()
22
  def predict_relevance(text):
23
  """Predict whether a text is relevant or not"""
24
  if not text.strip():
25
- return {"error": "Please provide some text to classify."}
26
 
27
  inputs = tokenizer(
28
  text,
@@ -46,7 +44,7 @@ def predict_relevance(text):
46
 
47
  return {
48
  "prediction": predicted_class, # 1 for relevant, 0 for not relevant
49
- "confidence": float(confidence),
50
  "text": text
51
  }
52
 
@@ -59,43 +57,36 @@ class EndpointHandler:
59
  def preprocess(self, data):
60
  # Extract the input text from the request
61
  text = data.get("inputs", "")
62
- return text
 
 
63
 
64
- def inference(self, text):
65
- if isinstance(text, list):
66
- # Handle batch prediction if multiple texts are provided
67
- results = []
68
- for t in text:
69
- if isinstance(t, dict):
70
- # Handle case where inputs come as list of dicts
71
- t = t.get("inputs", "")
72
- result = predict_relevance(t)
73
- results.append(result)
74
- return results
75
- else:
76
- # Single prediction
77
- return predict_relevance(text)
78
 
79
- def postprocess(self, output):
80
- if isinstance(output, list):
81
- # Process batch results
82
- return [{
83
- "prediction": "Relevant" if item["prediction"] == 1 else "Not Relevant",
84
- "confidence": item["confidence"],
85
- "text": item["text"]
86
- } for item in output]
87
- else:
88
- # Process single result
89
  if "error" in output:
90
- return {"error": output["error"]}
91
- return {
92
- "prediction": "Relevant" if output["prediction"] == 1 else "Not Relevant",
93
- "confidence": output["confidence"],
94
- "text": output["text"]
95
- }
 
 
 
 
 
 
96
 
97
  def __call__(self, data):
98
  # Main method to handle the request
99
- text = self.preprocess(data)
100
- output = self.inference(text)
101
- return self.postprocess(output)
 
2
  import torch
3
  import os
4
 
 
 
5
  # Initialize model and tokenizer
6
  model_name = "SCANSKY/distilbertTourism-multilingual-rclassifier"
7
  model = None
 
20
  def predict_relevance(text):
21
  """Predict whether a text is relevant or not"""
22
  if not text.strip():
23
+ return {"error": "Empty text provided."}
24
 
25
  inputs = tokenizer(
26
  text,
 
44
 
45
  return {
46
  "prediction": predicted_class, # 1 for relevant, 0 for not relevant
47
+ "confidence": float(confidence) * 100, # Convert to percentage
48
  "text": text
49
  }
50
 
 
57
  def preprocess(self, data):
58
  # Extract the input text from the request
59
  text = data.get("inputs", "")
60
+ # Split by newlines and remove empty lines
61
+ lines = [line.strip() for line in text.split('\n') if line.strip()]
62
+ return lines
63
 
64
+ def inference(self, lines):
65
+ results = []
66
+ for line in lines:
67
+ result = predict_relevance(line)
68
+ results.append(result)
69
+ return results
 
 
 
 
 
 
 
 
70
 
71
+ def postprocess(self, outputs):
72
+ processed_results = []
73
+ for output in outputs:
 
 
 
 
 
 
 
74
  if "error" in output:
75
+ processed_results.append({
76
+ "text": output.get("text", ""),
77
+ "error": output["error"],
78
+ "confidence": 0
79
+ })
80
+ else:
81
+ processed_results.append({
82
+ "text": output["text"],
83
+ "confidence": output["confidence"],
84
+ "relevance": "Relevant" if output["prediction"] == 1 else "Not Relevant"
85
+ })
86
+ return processed_results
87
 
88
  def __call__(self, data):
89
  # Main method to handle the request
90
+ lines = self.preprocess(data)
91
+ outputs = self.inference(lines)
92
+ return self.postprocess(outputs)