Newvel commited on
Commit
584224e
·
verified ·
1 Parent(s): debd0f1

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +38 -38
handler.py CHANGED
@@ -2,42 +2,42 @@ import json
2
  from transformers import pipeline
3
 
4
  class EndpointHandler:
5
- def __init__(self, model_dir):
6
- """
7
- Initialize the handler with optimized settings for low-resource environments.
8
- Args:
9
- model_dir (str): The directory where the fine-tuned model is stored.
10
- """
11
- print("Loading summarization pipeline...")
12
- self.summarizer = pipeline(
13
- "summarization",
14
- model=model_dir,
15
- device_map='cpu',
16
- max_length=150,
17
- min_length=30
18
- )
19
 
20
- def __call__(self, inputs):
21
- try:
22
- # If inputs is already a string, try to parse it
23
- if isinstance(inputs, str):
24
- try:
25
- input_data = json.loads(inputs)
26
- except json.JSONDecodeError:
27
- input_data = {"inputs": inputs}
28
- elif isinstance(inputs, dict):
29
- input_data = inputs
30
- else:
31
- raise ValueError("Input must be a string or dictionary")
32
-
33
- input_text = input_data.get("inputs", "")
34
- if not input_text:
35
- return json.dumps({"error": "No input text provided."})
36
-
37
- # Generate the summary
38
- summary = self.summarizer(input_text)[0]["summary_text"]
39
-
40
- return json.dumps({"summary": summary})
41
-
42
- except Exception as e:
43
- return json.dumps({"error": str(e)})
 
2
  from transformers import pipeline
3
 
4
  class EndpointHandler:
5
+ def __init__(self, model_dir):
6
+ """
7
+ Initialize the handler with optimized settings for low-resource environments.
8
+ Args:
9
+ model_dir (str): The directory where the fine-tuned model is stored.
10
+ """
11
+ print("Loading summarization pipeline...")
12
+ self.summarizer = pipeline(
13
+ "summarization",
14
+ model=model_dir,
15
+ device_map='cpu',
16
+ max_length=200,
17
+ min_length=40
18
+ )
19
 
20
+ def __call__(self, inputs):
21
+ try:
22
+ # If inputs is already a string, try to parse it
23
+ if isinstance(inputs, str):
24
+ try:
25
+ input_data = json.loads(inputs)
26
+ except json.JSONDecodeError:
27
+ input_data = {"inputs": inputs}
28
+ elif isinstance(inputs, dict):
29
+ input_data = inputs
30
+ else:
31
+ raise ValueError("Input must be a string or dictionary")
32
+
33
+ input_text = input_data.get("inputs", "")
34
+ if not input_text:
35
+ return json.dumps({"error": "No input text provided."})
36
+
37
+ # Generate the summary
38
+ summary = self.summarizer(input_text)[0]["summary_text"]
39
+
40
+ return json.dumps({"summary": summary})
41
+
42
+ except Exception as e:
43
+ return json.dumps({"error": str(e)})