Furqan1111 committed on
Commit
b346451
·
verified ·
1 Parent(s): b482a43

Create style_agent.py

Browse files
Files changed (1) hide show
  1. agents/style_agent.py +125 -0
agents/style_agent.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
+ import difflib
5
+
6
+
7
class StyleAgent:
    """Rewrite informal text into a more formal register.

    Wraps a Hugging Face seq2seq style-transfer model and, for each
    stylized text, reports a word-level diff of what changed plus a
    simple heuristic confidence score.
    """

    def __init__(
        self,
        model_name: str = "rajistics/informal_formal_style_transfer",
        device: Optional[str] = None,
    ):
        """
        Style Agent.

        Args:
            model_name: HF model id for informal -> formal style transfer.
            device: "cuda" or "cpu" (auto-detected when None).
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Prefer GPU only when the caller did not pin a device explicitly.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        self.model.to(self.device)

    def _generate(self, text: str, max_length: int = 256, num_beams: int = 5) -> str:
        """Run the seq2seq model and return the formalized text.

        Args:
            text: Raw input sentence/paragraph.
            max_length: Token cap for both input truncation and generation.
            num_beams: Beam-search width.

        Returns:
            The decoded, whitespace-stripped model output.
        """
        # Many style-transfer T5 models work directly on raw text.
        # If the model card suggests a prefix, add it here, e.g.:
        # text = "formal: " + text

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )

        styled = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return styled.strip()

    @staticmethod
    def _flush_change(changes: list, deleted: list, added: list) -> None:
        """Append one {from, to, type} record for the pending edit run, if any.

        No-op when both token runs are empty, so callers never produce
        an all-None record.
        """
        if deleted or added:
            changes.append(
                {
                    "from": " ".join(deleted) if deleted else None,
                    "to": " ".join(added) if added else None,
                    "type": StyleAgent._infer_change_type(deleted, added),
                }
            )

    def _diff_explanation(self, original: str, styled: str) -> list:
        """Compare original vs styled text and return word-level changes.

        Returns:
            A list of dicts with keys:
              - "from": space-joined deleted words, or None
              - "to":   space-joined inserted words, or None
              - "type": "deletion" | "insertion" | "replacement"
        """
        changes: list = []
        current_del: list = []
        current_add: list = []

        # ndiff emits "- w" (deleted), "+ w" (inserted), "  w" (unchanged);
        # its "? " hint lines match none of these prefixes and fall through.
        for token in difflib.ndiff(original.split(), styled.split()):
            if token.startswith("- "):
                current_del.append(token[2:])
            elif token.startswith("+ "):
                current_add.append(token[2:])
            elif token.startswith("  "):
                # An unchanged word closes the current run of edits.
                self._flush_change(changes, current_del, current_add)
                current_del, current_add = [], []

        # Flush a trailing run that was not followed by an unchanged word.
        self._flush_change(changes, current_del, current_add)
        return changes

    @staticmethod
    def _infer_change_type(deleted_tokens, added_tokens) -> str:
        """Classify an edit run by which sides of the diff are populated."""
        if deleted_tokens and not added_tokens:
            return "deletion"
        if added_tokens and not deleted_tokens:
            return "insertion"
        return "replacement"

    def stylize(self, text: str) -> dict:
        """Main method for TextDoctor.

        Args:
            text: Input text to formalize.

        Returns:
            {
                "original": input text,
                "styled": formalized text,
                "changes": [ {type, from, to}, ... ],
                "confidence": float in [0.3, 1.0],
                "agent": "style",
            }
        """
        styled = self._generate(text)
        changes = self._diff_explanation(text, styled)

        # Simple heuristic: the more word-level edits relative to input
        # length, the lower the confidence (floored at 0.3).
        change_ratio = len(changes) / max(len(text.split()), 1)
        confidence = max(0.3, 1.0 - change_ratio)

        return {
            "original": text,
            "styled": styled,
            "changes": changes,
            "confidence": round(confidence, 2),
            "agent": "style",
        }