sxtran committed on
Commit
7a92108
·
verified ·
1 Parent(s): e2f28aa

update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +130 -0
handler.py CHANGED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+ import torch
3
+ from difflib import SequenceMatcher
4
+
5
class EndpointHandler:
    """Inference Endpoint handler for the sxtran/paraphraser-ielts-t5-base model.

    Accepts a bare string, a list of strings, or a dict of the form
    {"inputs": str | list[str], "parameters": {...}} and returns paraphrases
    together with a word-level diff describing what changed.
    """

    def __init__(self, path=""):
        """Load tokenizer and seq2seq model, moving the model to GPU if available.

        Args:
            path: Local model directory supplied by the endpoint runtime;
                falls back to the published Hub checkpoint when empty.
        """
        model_name = path if path else "sxtran/paraphraser-ielts-t5-base"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def paraphrase_batch(self, sentences, num_return_sequences=1, temperature=1.0):
        """Paraphrase a batch of sentences with beam search.

        Args:
            sentences: List of input sentences.
            num_return_sequences: Candidates to return per sentence.
            temperature: Forwarded to generate().

        Returns:
            A flat list of strings (one per sentence) when
            num_return_sequences == 1, otherwise a list of lists grouping
            the candidates per input sentence.
        """
        inputs = self.tokenizer(
            sentences,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_length=512,
            num_beams=5,
            # NOTE(review): temperature has no effect under pure beam search
            # (do_sample defaults to False); kept for interface compatibility.
            temperature=temperature,
            num_return_sequences=num_return_sequences,
            early_stopping=True,
        )
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        if num_return_sequences > 1:
            # generate() returns candidates interleaved per input; regroup them.
            return [
                decoded[i * num_return_sequences:(i + 1) * num_return_sequences]
                for i in range(len(sentences))
            ]
        return decoded

    @staticmethod
    def _word_spans(text):
        """Return (start, end) character offsets of each whitespace-split word."""
        spans = []
        pos = 0
        for word in text.split():
            start = text.find(word, pos)
            spans.append((start, start + len(word)))
            pos = start + len(word)
        return spans

    def compute_changes(self, original, enhanced):
        """Word-level diff between *original* and *enhanced*.

        Returns one dict per replace/insert/delete opcode carrying the
        affected phrases, word indices (token_start/token_end) and character
        offsets into *original* (char_start/char_end).
        """
        # Hoisted: the previous implementation re-split both strings on every
        # opcode iteration.
        orig_words = original.split()
        enh_words = enhanced.split()
        orig_spans = self._word_spans(original)
        changes = []
        matcher = SequenceMatcher(None, orig_words, enh_words)
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag not in ("replace", "insert", "delete"):
                continue
            # Bug fix: char_start/char_end previously duplicated the *word*
            # indices; they now hold real character offsets into `original`.
            if i1 < i2:
                char_start = orig_spans[i1][0]
                char_end = orig_spans[i2 - 1][1]
            elif i1 < len(orig_spans):
                # Pure insertion: anchor both offsets at the insertion point.
                char_start = char_end = orig_spans[i1][0]
            else:
                # Insertion past the last word of the sentence.
                char_start = char_end = len(original)
            changes.append({
                "original_phrase": " ".join(orig_words[i1:i2]),
                "new_phrase": " ".join(enh_words[j1:j2]),
                "char_start": char_start,
                "char_end": char_end,
                "token_start": i1,
                "token_end": i2,
                "explanation": f"{tag} change",
                "error_type": "",
                "tip": ""
            })
        return changes

    def __call__(self, inputs):
        """Main entry point for the Hugging Face Endpoint.

        Args:
            inputs: A string, a list of strings, or a dict shaped like
                {"inputs": str | list[str], "parameters": {...}}.

        Returns:
            On success: {"success": True, "results": [...], ...counters}.
            On failure: {"success": False, "error": "...", ...}.
        """
        # Bug fix: a bare string previously fell through to the error branch
        # even though the error message promised string support.
        if isinstance(inputs, str):
            sentences = [inputs]
            parameters = {}
        elif isinstance(inputs, list):
            sentences = inputs
            parameters = {}
        elif isinstance(inputs, dict):
            # Common {"inputs": "...", "parameters": {}} wrapper format.
            sentences = inputs.get("inputs", [])
            if isinstance(sentences, str):
                sentences = [sentences]
            parameters = inputs.get("parameters", {})
        else:
            return {
                "success": False,
                "error": "Invalid input format. Expected a string, list of strings, or a dictionary with 'inputs' and 'parameters' keys."
            }

        # Optional generation parameters.
        num_return_sequences = parameters.get("num_return_sequences", 1)
        temperature = parameters.get("temperature", 1.0)

        if not sentences:
            return {
                "success": False,
                "error": "No sentences provided."
            }

        try:
            paraphrased = self.paraphrase_batch(sentences, num_return_sequences, temperature)
            results = []

            if num_return_sequences > 1:
                # Multiple candidates per sentence: one result row per candidate.
                for i, orig in enumerate(sentences):
                    for cand in paraphrased[i]:
                        results.append({
                            "original_sentence": orig,
                            "enhanced_sentence": cand,
                            "changes": self.compute_changes(orig, cand)
                        })
            else:
                # Single candidate per sentence: flat list pairs 1:1 with inputs.
                for orig, cand in zip(sentences, paraphrased):
                    results.append({
                        "original_sentence": orig,
                        "enhanced_sentence": cand,
                        "changes": self.compute_changes(orig, cand)
                    })

            return {
                "success": True,
                "results": results,
                "sentences_count": len(sentences),
                "processed_count": len(results),
                "skipped_count": 0,
                "error_count": 0
            }
        except Exception as e:
            # Top-level endpoint boundary: report the error rather than crash
            # the serving worker.
            return {
                "success": False,
                "error": str(e),
                "sentences_count": len(sentences),
                "processed_count": 0,
                "skipped_count": 0,
                "error_count": 1
            }