Shivangguptasih committed on
Commit
624eb07
·
verified ·
1 Parent(s): f2ded98

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from typing import Optional
4
+ import os
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
+ import re
8
+
9
# ASGI application object; the route decorators further down in this file
# register their handlers on it.
app = FastAPI(
    title="Caption Simplifier API",
)
10
+
11
class RobustSimplifier:
    """Shorten a caption using hand-written rules first and a seq2seq model second.

    ``simplify`` is the entry point. It runs ``rule_based_simplify`` and returns
    that result when ``is_good_simplification`` accepts it. Otherwise it asks the
    model, and uses the model output only if it is both acceptable and not
    flagged by ``has_hallucination``. In every other case the rule-based result
    is returned.
    """

    # Task prefix prepended to the text before it is handed to the tokenizer.
    _PROMPT_PREFIX = "simplify: "

    def __init__(self, model_path: str = "./model", max_words: int = 7):
        """Load the tokenizer and model and build the rule tables.

        Args:
            model_path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model. Defaults to ``"./model"``.
            max_words: Maximum number of words ``rule_based_simplify`` keeps.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        self.max_words = max_words

        # Lower-case phrase -> replacement. Keys must be lower-case because the
        # input is lower-cased before matching. When two keys overlap (for
        # example "launched" and "has launched") the longer one wins.
        self.simplification_rules = {
            "implementation of": "",
            "utilization of": "",
            "revolutionized": "changed",
            "enhanced": "improved",
            "launched": "started",
            "significant": "big",
            "remarkable": "great",
            "immediate": "urgent",
            "breakthrough": "finding",
            "methodologies": "methods",
            "artificial intelligence": "AI",
            "data processing": "data work",
            "medical attention": "medical help",
            "cancer treatment": "cancer care",
            "quantum physics": "quantum science",
            "has revolutionized": "changed",
            "has enhanced": "improved",
            "has launched": "started",
            "have discovered": "found",
            "have made": "created",
            "needs immediate": "needs urgent",
            "the government": "government",
            "the researchers": "researchers",
            "the scientists": "scientists",
            "the doctors": "doctors",
            "the patient": "patient",
        }

        # Function words dropped from the rule-based output.
        self.words_to_remove = {
            "the", "a", "an", "has", "have", "been", "is", "are", "was", "were",
            "of", "in", "on", "at", "by", "for", "with", "to", "from",
        }

    def simplify(self, text: str) -> str:
        """Return the best available simplification of ``text``.

        Preference order: an acceptable rule-based result, then an acceptable
        and non-hallucinated model result, then the rule-based result as a
        last resort (even if it was not judged acceptable).
        """
        rule_result = self.rule_based_simplify(text)
        if self.is_good_simplification(text, rule_result):
            return rule_result

        model_result = self.get_model_simplification(text)
        if self.is_good_simplification(text, model_result) and not self.has_hallucination(
            text, model_result
        ):
            return model_result

        return rule_result

    def rule_based_simplify(self, text: str) -> str:
        """Apply the phrase rules, drop function words and cap the word count.

        The text is lower-cased, every whole-word occurrence of a rule phrase is
        replaced, words listed in ``words_to_remove`` are dropped, at most
        ``max_words`` words are kept, and the first character is upper-cased.

        Returns:
            The simplified text; an empty string if no words survive.
        """
        lowered = text.lower()

        rules = self.simplification_rules
        if rules:
            # One pass over the text with the longest phrases tried first, so
            # "has launched" takes precedence over "launched". The \b anchors
            # prevent rewriting the inside of a longer word (for example
            # "insignificant" must not become "inbig").
            phrases = sorted(rules, key=len, reverse=True)
            pattern = re.compile(
                r"\b(?:" + "|".join(re.escape(phrase) for phrase in phrases) + r")\b"
            )
            lowered = pattern.sub(lambda match: rules[match.group(0)], lowered)

        kept = [word for word in lowered.split() if word not in self.words_to_remove]
        result = " ".join(kept[: self.max_words])

        if result:
            result = result[0].upper() + result[1:]
        return result

    def get_model_simplification(self, text: str) -> str:
        """Generate a simplification with the seq2seq model.

        Returns:
            The decoded model output with any echoed task prefix removed, or an
            empty string if tokenisation, generation or decoding raised.
        """
        # Local import so this class does not depend on a module-level import
        # that the surrounding file does not have.
        import logging

        prefix = self._PROMPT_PREFIX
        try:
            inputs = self.tokenizer(f"{prefix}{text}", return_tensors="pt")
            outputs = self.model.generate(
                inputs.input_ids,
                max_length=24,
                num_beams=1,
                do_sample=False,
            )
            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception:
            # Best effort: the caller falls back to the rule-based result.
            # Only Exception is caught, so KeyboardInterrupt/SystemExit still
            # propagate, and the failure is logged instead of vanishing.
            logging.getLogger(__name__).exception("Model simplification failed")
            return ""

        result = result.strip()
        if result.startswith(prefix):
            result = result[len(prefix):]
        return result.strip()

    def is_good_simplification(self, original: str, simplified: str) -> bool:
        """Decide whether ``simplified`` is an acceptable shortening of ``original``.

        Acceptable means: at least 3 characters, strictly shorter than the
        original, no longer than 70% of the original's length, and made up only
        of ASCII letters and whitespace.
        """
        simplified_len = len(simplified)
        original_len = len(original)

        if simplified_len < 3 or simplified_len >= original_len:
            return False
        if simplified_len > original_len * 0.7:
            return False
        return re.match(r'^[a-zA-Z\s]+$', simplified) is not None

    def has_hallucination(self, original: str, simplified: str) -> bool:
        """Heuristically detect content in ``simplified`` that is not in ``original``.

        Returns True when more than two words of the simplified text do not
        occur in the original, or when the simplified text contains one of a
        fixed list of phrases.
        """
        simplified_lower = simplified.lower()

        # Compare punctuation-free tokens. With a plain split(), "day." in the
        # source and "day" in the output would count as different words.
        original_words = set(re.findall(r"[\w']+", original.lower()))
        simplified_words = set(re.findall(r"[\w']+", simplified_lower))
        if len(simplified_words - original_words) > 2:
            return True

        # Fixed phrases that are always treated as hallucinated output.
        hallucination_patterns = [
            "disabilities", "cancer cells are more", "more susceptible to",
            "simplify the rules", "than cancer cells", "government rules",
        ]
        return any(pattern in simplified_lower for pattern in hallucination_patterns)
134
+
135
# Module-level simplifier instance shared by the request handlers below.
# It is constructed once, when this module is imported.
simplifier = RobustSimplifier()
137
+
138
class SimplifyRequest(BaseModel):
    """Request body for the ``/simplify`` endpoint."""

    # Text to be simplified (required).
    text: str
    # Language tag echoed back in the response; "en" when omitted.
    language: Optional[str] = "en"
141
+
142
class SimplifyResponse(BaseModel):
    """Response body returned by the ``/simplify`` endpoint."""

    # The text exactly as it was received.
    original: str
    # The simplified version of ``original``.
    simplified: str
    # Language tag taken from the request.
    language: str
146
+
147
@app.get("/")
def read_root():
    """Health-check endpoint: report that the service is up."""
    health = {
        "status": "healthy",
        "message": "Caption Simplifier API is running",
    }
    return health
150
+
151
@app.post("/simplify", response_model=SimplifyResponse)
def simplify_text(request: SimplifyRequest):
    """Simplify the submitted text and return it together with the original.

    Args:
        request: Body carrying the ``text`` to simplify and an optional
            ``language`` tag.

    Returns:
        A ``SimplifyResponse`` with the original text, the simplified text and
        the language tag.
    """
    simplified = simplifier.simplify(request.text)

    # ``language`` is Optional on the request but a required ``str`` on the
    # response, so an explicit ``"language": null`` would make building the
    # response fail. Fall back to the request model's own default instead.
    language = request.language if request.language is not None else "en"

    return SimplifyResponse(
        original=request.text,
        simplified=simplified,
        language=language,
    )
160
+
161
@app.get("/test")
def test_api():
    """Smoke-test endpoint: run the simplifier on a fixed sentence and report the result."""
    sample = "The government has launched a new scheme for improving education quality."
    return {
        "test_input": sample,
        "test_output": simplifier.simplify(sample),
        "status": "success",
    }