iamspruce committed on
Commit
f5d6f13
·
1 Parent(s): 3e24d97

updated the api

Browse files
Files changed (4) hide show
  1. app/models.py +67 -11
  2. app/prompts.py +83 -9
  3. app/routers/analyze.py +108 -26
  4. requirements.txt +1 -1
app/models.py CHANGED
@@ -1,39 +1,95 @@
1
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
2
  import torch
3
 
 
4
  device = torch.device("cpu")
5
 
6
- # Grammar model
 
 
7
  grammar_tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
8
  grammar_model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction").to(device)
9
 
10
- # FLAN-T5 for all prompts
 
 
11
  flan_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
12
  flan_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").to(device)
13
 
14
- # Translation model
 
15
  trans_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ROMANCE")
16
  trans_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ROMANCE").to(device)
17
 
18
- def run_grammar_correction(text: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  inputs = grammar_tokenizer(f"fix: {text}", return_tensors="pt").to(device)
 
20
  outputs = grammar_model.generate(**inputs)
 
21
  return grammar_tokenizer.decode(outputs[0], skip_special_tokens=True)
22
 
23
- def run_flan_prompt(prompt: str):
 
 
 
 
 
 
 
 
 
 
24
  inputs = flan_tokenizer(prompt, return_tensors="pt").to(device)
 
25
  outputs = flan_model.generate(**inputs)
 
26
  return flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
27
 
28
- def run_translation(text: str, target_lang: str):
 
 
 
 
 
 
 
 
 
 
 
29
  inputs = trans_tokenizer(f">>{target_lang}<< {text}", return_tensors="pt").to(device)
 
30
  outputs = trans_model.generate(**inputs)
 
31
  return trans_tokenizer.decode(outputs[0], skip_special_tokens=True)
32
 
 
 
 
33
 
34
- # Add this at the bottom of models.py
35
- tone_classifier = pipeline("text-classification", model="bhadresh-savani/bert-base-uncased-emotion", top_k=1)
36
 
37
- def classify_tone(text: str):
38
- result = tone_classifier(text)[0][0]
39
- return result['label']
 
 
 
 
 
1
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
2
  import torch
3
 
4
+ # Set the device for model inference (CPU is used by default)
5
  device = torch.device("cpu")
6
 
7
+ # --- Grammar model ---
8
+ # Uses vennify/t5-base-grammar-correction for grammar correction tasks.
9
+ # This model takes text and returns a grammatically corrected version.
10
  grammar_tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
11
  grammar_model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction").to(device)
12
 
13
+ # --- FLAN-T5 for all prompts ---
14
+ # Uses google/flan-t5-small for various text generation tasks based on prompts,
15
+ # such as paraphrasing, summarizing, and generating tone suggestions.
16
  flan_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
17
  flan_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").to(device)
18
 
19
+ # --- Translation model ---
20
+ # Uses Helsinki-NLP/opus-mt-en-ROMANCE for English to Romance language translation.
21
  trans_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ROMANCE")
22
  trans_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ROMANCE").to(device)
23
 
24
+ # --- Tone classification model ---
25
+ # Uses j-hartmann/emotion-english-distilroberta-base for detecting emotions/tones
26
+ # within text. This provides a more nuanced analysis than simple positive/negative.
27
+ # 'top_k=1' ensures that only the most confident label is returned.
28
+ tone_classifier = pipeline("sentiment-analysis", model="j-hartmann/emotion-english-distilroberta-base", top_k=1)
29
+
30
def run_grammar_correction(text: str) -> str:
    """
    Return a grammatically corrected version of *text*.

    Args:
        text (str): The input text to be grammatically corrected.

    Returns:
        str: The corrected text produced by the T5 grammar model.
    """
    # The grammar model expects its task prefix "fix: " before the input.
    encoded = grammar_tokenizer(f"fix: {text}", return_tensors="pt").to(device)
    generated = grammar_model.generate(**encoded)
    # Strip special tokens (<pad>, </s>, ...) when decoding back to text.
    return grammar_tokenizer.decode(generated[0], skip_special_tokens=True)
46
 
47
def run_flan_prompt(prompt: str) -> str:
    """
    Run *prompt* through FLAN-T5 and return its generated response.

    Args:
        prompt (str): The prompt string to be processed by FLAN-T5.

    Returns:
        str: The generated text response.
    """
    encoded = flan_tokenizer(prompt, return_tensors="pt").to(device)
    generated = flan_model.generate(**encoded)
    # Decode the generated token ids, dropping special tokens.
    return flan_tokenizer.decode(generated[0], skip_special_tokens=True)
63
 
64
def run_translation(text: str, target_lang: str) -> str:
    """
    Translate *text* into *target_lang* with the Helsinki-NLP model.

    Args:
        text (str): The input text to be translated.
        target_lang (str): Target language code (e.g. "fr" for French).

    Returns:
        str: The translated text.
    """
    # Opus-MT multi-target models select the output language via a
    # ">>lang<<" token prepended to the source text.
    encoded = trans_tokenizer(f">>{target_lang}<< {text}", return_tensors="pt").to(device)
    generated = trans_model.generate(**encoded)
    return trans_tokenizer.decode(generated[0], skip_special_tokens=True)
81
 
82
def classify_tone(text: str) -> str:
    """
    Return the most likely emotion label for *text*.

    Args:
        text (str): The input text for tone classification.

    Returns:
        str: The detected emotion label (e.g. 'neutral', 'joy', 'sadness').
    """
    # With top_k=1 the pipeline yields one inner list per input, each
    # holding a single {'label': ..., 'score': ...} dict — unpack it.
    (best,) = tone_classifier(text)[0]
    return best['label']
app/prompts.py CHANGED
@@ -1,23 +1,97 @@
1
- def tone_prompt(text, tone):
 
 
 
 
 
 
 
 
 
 
2
  return f"Rewrite the following text in a {tone} tone: {text}"
3
 
4
- def clarity_prompt(text):
 
 
 
 
 
 
 
 
 
5
  return f"Make this clearer: {text}"
6
 
7
- def fluency_prompt(text):
 
 
 
 
 
 
 
 
 
8
  return f"Improve the fluency of this sentence: {text}"
9
 
10
- def paraphrase_prompt(text):
 
 
 
 
 
 
 
 
 
11
  return f"Paraphrase: {text}"
12
 
13
- def summarize_prompt(text):
 
 
 
 
 
 
 
 
 
14
  return f"Summarize: {text}"
15
 
16
- def pronoun_friendly_prompt(text):
17
- return f"Rewrite the text using inclusive and non-offensive pronouns: {text}"
 
 
 
 
 
 
 
 
 
 
18
 
19
- def active_voice_prompt(text):
 
 
 
 
 
 
 
 
 
20
  return f"Detect if this is passive or active voice. If passive, suggest an active voice version: {text}"
21
 
22
- def tone_analysis_prompt(text):
 
 
 
 
 
 
 
 
 
23
  return f"Analyze the tone of the following text and suggest improvements if needed: {text}"
 
1
def tone_prompt(text: str, tone: str) -> str:
    """Build a prompt asking the model to rewrite *text* in *tone*
    (e.g. "formal", "informal", "confident")."""
    return f"Rewrite the following text in a {tone} tone: {text}"
13
 
14
def clarity_prompt(text: str) -> str:
    """Build a prompt asking the model to make *text* clearer."""
    return f"Make this clearer: {text}"
25
 
26
def fluency_prompt(text: str) -> str:
    """Build a prompt asking the model to improve the fluency of *text*."""
    return f"Improve the fluency of this sentence: {text}"
37
 
38
def paraphrase_prompt(text: str) -> str:
    """Build a prompt asking the model to paraphrase *text*."""
    return f"Paraphrase: {text}"
49
 
50
def summarize_prompt(text: str) -> str:
    """Build a prompt asking the model to summarize *text*."""
    return f"Summarize: {text}"
61
 
62
def pronoun_friendly_prompt(text: str) -> str:
    """Build a prompt asking the model to rewrite *text* with inclusive,
    respectful language that avoids gender-specific pronouns."""
    return f"Rewrite the following text using inclusive, respectful language avoiding gender-specific pronouns: {text}"
74
 
75
def active_voice_prompt(text: str) -> str:
    """Build a prompt asking the model to detect passive/active voice in
    *text* and, if passive, suggest an active-voice rewrite."""
    return f"Detect if this is passive or active voice. If passive, suggest an active voice version: {text}"
86
 
87
def tone_analysis_prompt(text: str) -> str:
    """Build a prompt asking the model to analyze the tone of *text* and
    suggest improvements where needed."""
    return f"Analyze the tone of the following text and suggest improvements if needed: {text}"
app/routers/analyze.py CHANGED
@@ -4,50 +4,132 @@ from app import models, prompts
4
  from app.core.security import verify_api_key
5
  import language_tool_python
6
  import spacy
 
7
 
8
  router = APIRouter()
 
 
 
9
  nlp = spacy.load("en_core_web_sm")
 
 
 
10
  tool = language_tool_python.LanguageTool('en-US')
11
 
12
  class AnalyzeInput(BaseModel):
 
 
 
 
13
  text: str
14
 
15
  @router.post("/analyze")
16
  def analyze_text(payload: AnalyzeInput, request: Request = Depends(verify_api_key)):
 
 
 
 
 
 
 
 
 
 
 
17
  text = payload.text
18
 
19
- # 1. Grammar Correction
20
- grammar = models.run_grammar_correction(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # 2. Punctuation Fixes
 
23
  matches = tool.check(text)
24
- punctuation_fixes = [m.message for m in matches if 'PUNCTUATION' in m.ruleId.upper()]
 
 
 
 
 
 
 
 
 
 
25
 
26
- # 3. Sentence Correctness Tips
27
- sentence_issues = [m.message for m in matches if 'PUNCTUATION' not in m.ruleId.upper()]
 
 
 
 
 
 
 
 
 
 
28
 
29
- # 4. Tone Detection
30
- tone_result = models.classify_tone(text)
31
- better_tone_version = models.run_flan_prompt(prompts.tone_prompt(text, "formal"))
32
 
33
- # 5. Active/Passive Voice
34
- doc = nlp(text)
35
- voice = "passive" if any(tok.dep_ == "auxpass" for tok in doc) else "active"
36
- if voice == "passive":
37
- better_voice = models.run_flan_prompt(f"Rewrite this in active voice: {text}")
38
- else:
39
- better_voice = "Already in active voice"
 
 
 
 
 
 
 
40
 
41
- # 6. Inclusive Pronoun Suggestion
42
- inclusive = models.run_flan_prompt(prompts.pronoun_friendly_prompt(text))
 
43
 
 
44
  return {
45
- "grammar": grammar,
46
- "punctuation_fixes": punctuation_fixes,
47
- "sentence_issues": sentence_issues,
48
- "tone": tone_result,
49
- "tone_suggestion": better_tone_version,
50
- "voice": voice,
51
- "voice_suggestion": better_voice,
52
- "inclusive_pronouns": inclusive
 
 
 
 
 
 
 
 
 
 
53
  }
 
4
  from app.core.security import verify_api_key
5
  import language_tool_python
6
  import spacy
7
+ import difflib # Import the difflib module for text comparisons
8
 
9
router = APIRouter()

# spaCy English pipeline; its dependency parse is used below for
# active/passive voice detection.
nlp = spacy.load("en_core_web_sm")

# LanguageTool checker for grammar, punctuation and style (US English).
tool = language_tool_python.LanguageTool('en-US')
18
 
19
class AnalyzeInput(BaseModel):
    """Request body for POST /analyze: a single 'text' field to analyse."""
    # Raw text submitted by the client.
    text: str
25
 
26
def _word_diff(original: str, corrected: str) -> list:
    """
    Describe word-level edits that turn *original* into *corrected*.

    Returns a list of human-readable strings such as
    "'teh' \u2192 'the'", "'very' removed", "'a' added".
    """
    # Split once up front — the previous version re-split both full
    # strings inside the opcode loop, doing O(n) work per opcode.
    src = original.split()
    dst = corrected.split()
    changes = []
    matcher = difflib.SequenceMatcher(None, src, dst)
    for opcode, i1, i2, j1, j2 in matcher.get_opcodes():
        if opcode == 'replace':
            changes.append(f"'{' '.join(src[i1:i2])}' \u2192 '{' '.join(dst[j1:j2])}'")
        elif opcode == 'delete':
            changes.append(f"'{' '.join(src[i1:i2])}' removed")
        elif opcode == 'insert':
            changes.append(f"'{' '.join(dst[j1:j2])}' added")
        # 'equal' spans produce no change entry.
    return changes

@router.post("/analyze")
def analyze_text(payload: AnalyzeInput, request: Request = Depends(verify_api_key)):
    """
    Analyze the provided text for grammar, punctuation, sentence
    correctness, tone, active/passive voice, and inclusive-pronoun
    suggestions.

    Args:
        payload (AnalyzeInput): Request body containing the text to analyze.
        request (Request): FastAPI Request, injected via verify_api_key
            for API-key authentication.

    Returns:
        dict: Structured analysis results (grammar, punctuation,
        sentence_correctness, tone_analysis, voice, inclusive_pronouns).
    """
    text = payload.text

    # --- 1. Grammar correction with a word-level diff ---
    corrected_grammar = models.run_grammar_correction(text)
    grammar_changes = _word_diff(text, corrected_grammar)

    # --- 2. Punctuation fixes and 3. sentence-correctness feedback ---
    matches = tool.check(text)
    punctuation_issues = []
    sentence_correctness_feedback = []
    for m in matches:
        # ruleId may be None for some matches; guard before .upper() so a
        # single odd match cannot crash the endpoint.
        rule_id = (m.ruleId or "").upper()
        if 'PUNCTUATION' in rule_id:
            punctuation_issues.append(m.message)
        else:
            # Everything else is general sentence-correctness feedback.
            sentence_correctness_feedback.append(m.message)

    # --- 4. Tone detection and suggestion ---
    detected_tone = models.classify_tone(text)
    if detected_tone in ("neutral", "joy"):
        # For these tones, offer a formal rewrite via FLAN-T5.
        tone_suggestion_text = models.run_flan_prompt(prompts.tone_prompt(text, "formal"))
    else:
        tone_suggestion_text = f"The detected tone '{detected_tone}' seems appropriate for general communication."

    # --- 5. Active/passive voice detection and suggestion ---
    doc = nlp(text)
    # Heuristic: a passive-auxiliary dependency ('auxpass') indicates
    # a passive construction.
    if any(token.dep_ == "auxpass" for token in doc):
        voice_detected = "passive"
        voice_suggestion = models.run_flan_prompt(prompts.active_voice_prompt(text))
    else:
        voice_detected = "active"
        voice_suggestion = "None \u2014 active voice is fine here."

    # --- 6. Inclusive pronoun suggestion ---
    inclusive_pronouns_suggestion = models.run_flan_prompt(prompts.pronoun_friendly_prompt(text))

    # --- Final response structure ---
    return {
        "grammar": {
            "corrected": corrected_grammar,
            "changes": grammar_changes
        },
        "punctuation": {
            "issues": punctuation_issues,
            "suggestions": []  # diffs in "grammar.changes" carry the concrete fixes
        },
        "sentence_correctness": sentence_correctness_feedback,
        "tone_analysis": {
            "detected": detected_tone,
            "suggestion": tone_suggestion_text
        },
        "voice": {
            "detected": voice_detected,
            "suggestion": voice_suggestion
        },
        "inclusive_pronouns": inclusive_pronouns_suggestion
    }
requirements.txt CHANGED
@@ -7,4 +7,4 @@ pyspellchecker
7
  spacy
8
  nltk
9
  language-tool-python
10
- scikit-learn
 
7
  spacy
8
  nltk
9
  language-tool-python
10
+ scikit-learn