Furqan1111 committed
Commit dbfd080 · verified · 1 Parent(s): b14ccb8

Create clarity_agent.py

Files changed (1)
  1. agents/clarity_agent.py +134 -0
agents/clarity_agent.py ADDED
@@ -0,0 +1,134 @@
+ from typing import Optional
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ import torch
+ import difflib
+
+
+ class ClarityAgent:
+     def __init__(
+         self,
+         model_name: str = "Vamsi/T5_Paraphrase_Paws",
+         device: Optional[str] = None,
+     ):
+         """
+         Clarity Agent
+         - Uses a paraphrasing model to restate sentences more clearly.
+         - model_name: Hugging Face model ID.
+         - device: "cuda" or "cpu" (auto-detect if None).
+         """
+         self.model_name = model_name
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+         if device is None:
+             self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         else:
+             self.device = device
+
+         self.model.to(self.device)
+
+     def _generate(
+         self,
+         text: str,
+         max_length: int = 256,
+         num_beams: int = 5,
+         num_return_sequences: int = 1,
+     ) -> str:
+         """
+         Internal helper to generate a clearer paraphrase of the input sentence.
+         """
+         # Many T5 paraphrase models expect a prefix like "paraphrase: "
+         prefixed = "paraphrase: " + text + " </s>"
+
+         inputs = self.tokenizer(
+             [prefixed],
+             max_length=max_length,
+             padding="longest",
+             truncation=True,
+             return_tensors="pt",
+         ).to(self.device)
+
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **inputs,
+                 max_length=max_length,
+                 num_beams=num_beams,
+                 num_return_sequences=num_return_sequences,
+                 early_stopping=True,
+             )
+
+         # Take the first generated sequence
+         paraphrased = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return paraphrased.strip()
+
+     def _diff_explanation(self, original: str, clarified: str):
+         """
+         Compare original vs clarified sentence and return simple word-level changes.
+         """
+         diff = list(difflib.ndiff(original.split(), clarified.split()))
+         changes = []
+         current_del = []
+         current_add = []
+
+         for token in diff:
+             if token.startswith("- "):
+                 current_del.append(token[2:])
+             elif token.startswith("+ "):
+                 current_add.append(token[2:])
+             elif token.startswith("  "):
+                 if current_del or current_add:
+                     changes.append(
+                         {
+                             "from": " ".join(current_del) if current_del else None,
+                             "to": " ".join(current_add) if current_add else None,
+                             "type": self._infer_change_type(current_del, current_add),
+                         }
+                     )
+                     current_del, current_add = [], []
+
+         if current_del or current_add:
+             changes.append(
+                 {
+                     "from": " ".join(current_del) if current_del else None,
+                     "to": " ".join(current_add) if current_add else None,
+                     "type": self._infer_change_type(current_del, current_add),
+                 }
+             )
+
+         changes = [c for c in changes if c["from"] or c["to"]]
+         return changes
+
+     @staticmethod
+     def _infer_change_type(deleted_tokens, added_tokens):
+         if deleted_tokens and not added_tokens:
+             return "deletion"
+         if added_tokens and not deleted_tokens:
+             return "insertion"
+         return "replacement"
+
+     def clarify(self, text: str) -> dict:
+         """
+         Main method for TextDoctor.
+         Returns:
+             {
+                 "original": ...,
+                 "clarified": ...,
+                 "changes": [ {type, from, to}, ... ],
+                 "confidence": float,
+                 "agent": "clarity"
+             }
+         """
+         clarified = self._generate(text)
+         changes = self._diff_explanation(text, clarified)
+
+         # simple heuristic confidence
+         change_ratio = len(changes) / max(len(text.split()), 1)
+         confidence = max(0.3, 1.0 - change_ratio)
+
+         return {
+             "original": text,
+             "clarified": clarified,
+             "changes": changes,
+             "confidence": round(confidence, 2),
+             "agent": "clarity",
+         }
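
For reference, a minimal usage sketch of the class added in this commit. It assumes the file is importable as agents/clarity_agent.py (as listed above) and that the Vamsi/T5_Paraphrase_Paws weights can be downloaded from the Hub on first run; the example sentence is illustrative only.

from agents.clarity_agent import ClarityAgent

# Instantiate once; device is auto-detected (CUDA if available, else CPU).
agent = ClarityAgent()

# clarify() returns the dict described in its docstring.
result = agent.clarify(
    "The report which was written by the team last week it contains several findings."
)

print(result["clarified"])     # paraphrased sentence
print(result["confidence"])    # heuristic score, bounded to [0.3, 1.0]
for change in result["changes"]:
    print(change["type"], change["from"], "->", change["to"])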