iamspruce committed on
Commit
c37cee4
·
1 Parent(s): 8d34c33

updated the api

Browse files
Files changed (1) hide show
  1. app/models.py +16 -22
app/models.py CHANGED
@@ -6,10 +6,10 @@ import torch
6
  device = torch.device("cpu")
7
 
8
  # --- Grammar model ---
9
- # Changed to humarin/t5-small-grammar-correction for potentially better performance
10
- # on common spelling and grammar issues compared to vennify/t5-base-grammar-correction.
11
- grammar_tokenizer = AutoTokenizer.from_pretrained("humarin/t5-small-grammar-correction")
12
- grammar_model = AutoModelForSeq2SeqLM.from_pretrained("humarin/t5-small-grammar-correction").to(device)
13
 
14
  # --- FLAN-T5 for all prompts ---
15
  # Uses google/flan-t5-small for various text generation tasks based on prompts,
@@ -38,18 +38,12 @@ def run_grammar_correction(text: str) -> str:
38
  Returns:
39
  str: The corrected text.
40
  """
41
- # Prepare the input for the grammar model by prefixing with "fix: "
42
- inputs = grammar_tokenizer(f"fix: {text}", return_tensors="pt").to(device)
 
 
43
  # Generate the corrected output
44
- outputs = grammar_model.generate(
45
- **inputs,
46
- max_new_tokens=50, # adjust as needed
47
- num_beams=5,
48
- do_sample=True,
49
- top_k=50,
50
- top_p=0.95,
51
- temperature=0.7
52
- )
53
  # Decode the generated tokens back into a readable string, skipping special tokens
54
  return grammar_tokenizer.decode(outputs[0], skip_special_tokens=True)
55
 
@@ -68,14 +62,14 @@ def run_flan_prompt(prompt: str) -> str:
68
  inputs = flan_tokenizer(prompt, return_tensors="pt").to(device)
69
 
70
  # Generate the output with improved parameters:
71
- # max_new_tokens: Limits the maximum length of the generated response.
72
- # num_beams: Uses beam search for higher quality, less repetitive outputs.
73
- # temperature: Controls randomness; lower means more deterministic.
74
  outputs = flan_model.generate(
75
  **inputs,
76
- max_new_tokens=100, # Limit output length to prevent rambling
77
- num_beams=5, # Use beam search for better quality
78
- temperature=0.7 # Controls randomness; lower means more deterministic
 
 
 
79
  )
80
  # Decode the generated tokens back into a readable string
81
  return flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -110,5 +104,5 @@ def classify_tone(text: str) -> str:
110
  """
111
  # The tone_classifier returns a list of dictionaries, where each dictionary
112
  # contains 'label' and 'score'. We extract the 'label' from the first (and only) result.
113
- result = tone_classifier(text)[0][0] # Access the first item in the list, then the first element of that list
114
  return result['label']
 
6
  device = torch.device("cpu")
7
 
8
  # --- Grammar model ---
9
+ # Changed to deepashri/t5-small-grammar-correction, a publicly available model
10
+ # for grammatical error correction. This model is fine-tuned from T5-small.
11
+ grammar_tokenizer = AutoTokenizer.from_pretrained("deepashri/t5-small-grammar-correction")
12
+ grammar_model = AutoModelForSeq2SeqLM.from_pretrained("deepashri/t5-small-grammar-correction").to(device)
13
 
14
  # --- FLAN-T5 for all prompts ---
15
  # Uses google/flan-t5-small for various text generation tasks based on prompts,
 
38
  Returns:
39
  str: The corrected text.
40
  """
41
+ # Prepare the input for the grammar model by prefixing with "grammar: " as per
42
+ # the 'deepashri/t5-small-grammar-correction' model's expected input format.
43
+ # Some grammar correction models expect a specific prefix like "grammar: " or "fix: ".
44
+ inputs = grammar_tokenizer(f"grammar: {text}", return_tensors="pt").to(device)
45
  # Generate the corrected output
46
+ outputs = grammar_model.generate(**inputs)
 
 
 
 
 
 
 
 
47
  # Decode the generated tokens back into a readable string, skipping special tokens
48
  return grammar_tokenizer.decode(outputs[0], skip_special_tokens=True)
49
 
 
62
  inputs = flan_tokenizer(prompt, return_tensors="pt").to(device)
63
 
64
  # Generate the output with improved parameters:
 
 
 
65
  outputs = flan_model.generate(
66
  **inputs,
67
+ max_new_tokens=100,
68
+ num_beams=5,
69
+ do_sample=True,
70
+ top_k=50,
71
+ top_p=0.95,
72
+ temperature=0.7
73
  )
74
  # Decode the generated tokens back into a readable string
75
  return flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
104
  """
105
  # The tone_classifier returns a list of dictionaries, where each dictionary
106
  # contains 'label' and 'score'. We extract the 'label' from the first (and only) result.
107
+ result = tone_classifier(text)[0][0]
108
  return result['label']