Furqan1111 committed on
Commit
b482a43
·
verified ·
1 Parent(s): 74527c1

Create grammar_agent.py

Browse files
Files changed (1) hide show
  1. agents/grammar_agent.py +128 -0
agents/grammar_agent.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
+ import difflib
5
+
6
+
7
class GrammarAgent:
    """Seq2seq grammar-correction agent.

    Wraps a T5-style Hugging Face model and returns corrected text together
    with a word-level diff explanation and a heuristic confidence score.
    """

    def __init__(
        self,
        model_name: str = "vennify/t5-base-grammar-correction",
        device: Optional[str] = None,
    ):
        """
        Grammar Agent
        - model_name: HF model id
        - device: "cuda" or "cpu" (auto-detect if None)
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Auto-detect GPU only when the caller does not pin a device.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.model.to(self.device)

    def _generate(self, text: str, max_length: int = 256, num_beams: int = 5) -> str:
        """Run the model on *text* and return the decoded corrected string.

        - max_length: cap on both input tokens and generated tokens
        - num_beams: beam-search width (larger = better quality, slower)
        """
        # The vennify T5 grammar model is trained with this task prefix.
        prefixed = "grammar: " + text

        inputs = self.tokenizer(
            prefixed,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )

        corrected = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return corrected.strip()

    def _diff_explanation(self, original: str, corrected: str):
        """Create a simple, human-readable explanation of changes.

        Compares the two texts word-by-word via difflib.ndiff and groups
        consecutive deletions/insertions into single change records.
        Returns a list of {"from", "to", "type"} dicts where "from"/"to"
        are None for pure insertions/deletions respectively.
        """
        diff = difflib.ndiff(original.split(), corrected.split())
        changes = []
        pending_del: list = []
        pending_add: list = []

        def _flush():
            # Record any accumulated run of deletions/insertions as one change.
            if pending_del or pending_add:
                changes.append(
                    {
                        "from": " ".join(pending_del) if pending_del else None,
                        "to": " ".join(pending_add) if pending_add else None,
                        "type": self._infer_change_type(pending_del, pending_add),
                    }
                )
                pending_del.clear()
                pending_add.clear()

        for token in diff:
            if token.startswith("- "):
                pending_del.append(token[2:])
            elif token.startswith("+ "):
                pending_add.append(token[2:])
            elif token.startswith("  "):
                # Unchanged word: close out any pending run of edits.
                _flush()
            # "? " hint lines emitted by ndiff are intentionally ignored.

        _flush()  # trailing run at the end of the diff

        # Drop degenerate entries (both sides empty).
        return [c for c in changes if c["from"] or c["to"]]

    @staticmethod
    def _infer_change_type(deleted_tokens, added_tokens):
        """
        Very simple heuristic for change type.
        You can later improve this with more logic.
        """
        if deleted_tokens and not added_tokens:
            return "deletion"
        if added_tokens and not deleted_tokens:
            return "insertion"
        return "replacement"

    def correct(self, text: str) -> dict:
        """
        Main method your system will call.
        Returns a dict:
            {
                "original": ...,
                "corrected": ...,
                "changes": [ {type, from, to}, ... ],
                "confidence": float,
                "agent": "grammar"
            }
        """
        # Robustness: skip the model entirely for empty/whitespace-only
        # input -- seq2seq models can hallucinate output for a bare prefix.
        if not text.strip():
            return {
                "original": text,
                "corrected": text,
                "changes": [],
                "confidence": 1.0,
                "agent": "grammar",
            }

        corrected = self._generate(text)
        changes = self._diff_explanation(text, corrected)

        # simple heuristic confidence based on how much was changed:
        # the larger the fraction of words touched, the lower the score,
        # floored at 0.3.
        change_ratio = len(changes) / max(len(text.split()), 1)
        confidence = max(0.3, 1.0 - change_ratio)

        return {
            "original": text,
            "corrected": corrected,
            "changes": changes,
            "confidence": round(confidence, 2),
            "agent": "grammar",
        }