Sontran0108 committed on
Commit
a43192e
·
1 Parent(s): 5637bcd

Add CoEdIT-Large model with inference handler

Browse files
Files changed (1) hide show
  1. handler.py +135 -0
handler.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
2
+ import torch
3
+ from difflib import SequenceMatcher
4
+
5
class EndpointHandler:
    """Inference handler for the grammarly/coedit-large grammar-correction model.

    Loads a T5 sequence-to-sequence model and exposes the Hugging Face
    Inference Endpoint entry point (``__call__``), which fixes the grammar of
    one or more sentences and reports word-level diffs for each correction.
    """

    def __init__(self, path=""):
        """Load tokenizer and model, moving the model to GPU when available.

        Args:
            path: Local model directory supplied by the endpoint runtime;
                falls back to the Hub id ``grammarly/coedit-large`` when empty.
        """
        model_name = path if path else "grammarly/coedit-large"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Inference-only handler: disable dropout / training-mode behavior.
        self.model.eval()

    def paraphrase_batch(self, sentences, num_return_sequences=1, temperature=1.0):
        """Run grammar correction over a batch of sentences.

        Args:
            sentences: List of raw input sentences.
            num_return_sequences: Number of candidates to generate per sentence.
            temperature: Generation temperature. NOTE(review): generation uses
                beam search (``num_beams=5``) without ``do_sample=True``, so
                ``temperature`` currently has no effect — confirm whether
                sampling was intended before relying on this parameter.

        Returns:
            A flat list of corrected sentences when ``num_return_sequences == 1``,
            otherwise one list of candidates per input sentence.
        """
        # CoEdIT is steered by a natural-language task prefix.
        prefix = "Fix the grammar: "
        sentences_with_prefix = [prefix + s for s in sentences]

        inputs = self.tokenizer(
            sentences_with_prefix,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(self.device)

        # no_grad: avoid building autograd state that inference never uses.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=512,
                num_beams=5,
                temperature=temperature,
                num_return_sequences=num_return_sequences,
                early_stopping=True
            )

        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        if num_return_sequences > 1:
            # generate() returns candidates flattened; regroup per input sentence.
            grouped = [
                decoded[i * num_return_sequences:(i + 1) * num_return_sequences]
                for i in range(len(sentences))
            ]
            return grouped
        else:
            return decoded

    @staticmethod
    def _token_spans(text):
        """Return (start, end) character offsets of each whitespace token.

        Mirrors ``str.split()`` tokenization so the offsets line up with the
        token indices produced by ``SequenceMatcher`` in ``compute_changes``.
        """
        spans = []
        pos = 0
        for tok in text.split():
            start = text.find(tok, pos)
            spans.append((start, start + len(tok)))
            pos = start + len(tok)
        return spans

    def compute_changes(self, original, enhanced):
        """Diff two sentences and describe each word-level edit.

        Bug fix: the previous version stored token indices under
        ``"char_start"``/``"char_end"`` (duplicating ``"token_start"``/
        ``"token_end"``). Those keys now carry true character offsets into
        ``original``; the token indices remain under the token keys.

        Args:
            original: The input sentence before correction.
            enhanced: The corrected sentence.

        Returns:
            A list of change dicts, one per replace/insert/delete opcode.
        """
        # Tokenize once, outside the opcode loop.
        orig_tokens = original.split()
        new_tokens = enhanced.split()
        orig_spans = self._token_spans(original)

        changes = []
        matcher = SequenceMatcher(None, orig_tokens, new_tokens)
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag in ("replace", "insert", "delete"):
                if i1 < i2:
                    # Span of the affected tokens in the original string.
                    char_start = orig_spans[i1][0]
                    char_end = orig_spans[i2 - 1][1]
                else:
                    # Pure insertion: zero-width anchor before token i1
                    # (or at end of string when inserting at the tail).
                    char_start = orig_spans[i1][0] if i1 < len(orig_spans) else len(original)
                    char_end = char_start
                changes.append({
                    "original_phrase": " ".join(orig_tokens[i1:i2]),
                    "new_phrase": " ".join(new_tokens[j1:j2]),
                    "char_start": char_start,
                    "char_end": char_end,
                    "token_start": i1,
                    "token_end": i2,
                    "explanation": f"{tag} change",
                    "error_type": "",
                    "tip": ""
                })
        return changes

    def __call__(self, inputs):
        """Main entry point for the Hugging Face Inference Endpoint.

        Accepts a single string, a list of strings, or a dict in the
        ``{"inputs": ..., "parameters": {...}}`` endpoint format.

        Returns:
            A dict with ``success``, ``results`` (original/enhanced sentence
            pairs plus per-sentence change lists), and bookkeeping counts;
            on failure, ``success`` is False and ``error`` holds the reason.
        """
        # Normalize the three accepted input shapes into (sentences, parameters).
        if isinstance(inputs, list):
            sentences = inputs
            parameters = {}
        elif isinstance(inputs, dict):
            # Common {"inputs": "...", "parameters": {}} endpoint format.
            sentences = inputs.get("inputs", [])
            # A single bare string is wrapped into a one-element batch.
            if isinstance(sentences, str):
                sentences = [sentences]
            parameters = inputs.get("parameters", {})
        else:
            return {
                "success": False,
                "error": "Invalid input format. Expected a string, list of strings, or a dictionary with 'inputs' and 'parameters' keys."
            }

        # Optional generation parameters with safe defaults.
        num_return_sequences = parameters.get("num_return_sequences", 1)
        temperature = parameters.get("temperature", 1.0)

        if not sentences:
            return {
                "success": False,
                "error": "No sentences provided."
            }

        try:
            paraphrased = self.paraphrase_batch(sentences, num_return_sequences, temperature)
            results = []

            if num_return_sequences > 1:
                # Multiple candidates per sentence: flatten each candidate
                # into its own result entry.
                for i, orig in enumerate(sentences):
                    for cand in paraphrased[i]:
                        results.append({
                            "original_sentence": orig,
                            "enhanced_sentence": cand,
                            "changes": self.compute_changes(orig, cand)
                        })
            else:
                # Single candidate per sentence: pair originals with outputs.
                for orig, cand in zip(sentences, paraphrased):
                    results.append({
                        "original_sentence": orig,
                        "enhanced_sentence": cand,
                        "changes": self.compute_changes(orig, cand)
                    })

            return {
                "success": True,
                "results": results,
                "sentences_count": len(sentences),
                "processed_count": len(results),
                "skipped_count": 0,
                "error_count": 0
            }
        except Exception as e:
            # Top-level endpoint boundary: surface the error in the response
            # rather than letting the serving stack return an opaque 500.
            return {
                "success": False,
                "error": str(e),
                "sentences_count": len(sentences),
                "processed_count": 0,
                "skipped_count": 0,
                "error_count": 1
            }