iamspruce committed on
Commit
8d34c33
·
1 Parent(s): 1cfeb58

updated the api

Browse files
Files changed (2) hide show
  1. Dockerfile +3 -0
  2. app/models.py +14 -12
Dockerfile CHANGED
@@ -22,6 +22,9 @@ RUN python -m spacy download en_core_web_sm
22
  ENV HF_HOME=/cache
23
  RUN mkdir -p /cache && chmod -R 777 /cache
24
 
 
 
 
25
  # Explicitly create and set permissions for /.cache
26
  # This is to address PermissionError: [Errno 13] Permission denied: '/.cache'
27
  # which language-tool-python might be trying to write to.
 
22
  ENV HF_HOME=/cache
23
  RUN mkdir -p /cache && chmod -R 777 /cache
24
 
25
+ # ... other Dockerfile content ...
26
+ ENV GRAMMAFREE_API_KEY="admin"  # NOTE(review): hardcoded credential baked into the image — presumably a placeholder; prefer injecting at runtime (docker run -e) or Docker build secrets. Confirm before release.
27
+
28
  # Explicitly create and set permissions for /.cache
29
  # This is to address PermissionError: [Errno 13] Permission denied: '/.cache'
30
  # which language-tool-python might be trying to write to.
app/models.py CHANGED
@@ -6,12 +6,10 @@ import torch
6
  device = torch.device("cpu")
7
 
8
  # --- Grammar model ---
9
- # Uses vennify/t5-base-grammar-correction for grammar correction tasks.
10
- # Note: This model might not catch all subtle spelling or advanced grammar errors
11
- # as robustly as larger models or rule-based systems. Its performance depends on
12
- # its training data.
13
- grammar_tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
14
- grammar_model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction").to(device)
15
 
16
  # --- FLAN-T5 for all prompts ---
17
  # Uses google/flan-t5-small for various text generation tasks based on prompts,
@@ -43,7 +41,15 @@ def run_grammar_correction(text: str) -> str:
43
  # Prepare the input for the grammar model by prefixing with "fix: "
44
  inputs = grammar_tokenizer(f"fix: {text}", return_tensors="pt").to(device)
45
  # Generate the corrected output
46
- outputs = grammar_model.generate(**inputs)
 
 
 
 
 
 
 
 
47
  # Decode the generated tokens back into a readable string, skipping special tokens
48
  return grammar_tokenizer.decode(outputs[0], skip_special_tokens=True)
49
 
@@ -64,15 +70,11 @@ def run_flan_prompt(prompt: str) -> str:
64
  # Generate the output with improved parameters:
65
  # max_new_tokens: Limits the maximum length of the generated response.
66
  # num_beams: Uses beam search for higher quality, less repetitive outputs.
67
- # do_sample: Enables sampling, allowing for more diverse outputs.
68
- # top_k, top_p: Control the sampling process, making it more focused and coherent.
69
  outputs = flan_model.generate(
70
  **inputs,
71
  max_new_tokens=100, # Limit output length to prevent rambling
72
  num_beams=5, # Use beam search for better quality
73
- do_sample=True, # Enable sampling for diversity
74
- top_k=50, # Sample from top 50 most probable tokens
75
- top_p=0.95, # Sample from tokens that cumulatively exceed 95% probability
76
  temperature=0.7 # Controls randomness; lower means more deterministic
77
  )
78
  # Decode the generated tokens back into a readable string
 
6
  device = torch.device("cpu")
7
 
8
  # --- Grammar model ---
9
+ # Changed to humarin/t5-small-grammar-correction for potentially better performance
10
+ # on common spelling and grammar issues compared to vennify/t5-base-grammar-correction.
11
+ grammar_tokenizer = AutoTokenizer.from_pretrained("humarin/t5-small-grammar-correction")
12
+ grammar_model = AutoModelForSeq2SeqLM.from_pretrained("humarin/t5-small-grammar-correction").to(device)
 
 
13
 
14
  # --- FLAN-T5 for all prompts ---
15
  # Uses google/flan-t5-small for various text generation tasks based on prompts,
 
41
  # Prepare the input for the grammar model by prefixing with "fix: "
42
  inputs = grammar_tokenizer(f"fix: {text}", return_tensors="pt").to(device)
43
  # Generate the corrected output
44
+ outputs = grammar_model.generate(
45
+ **inputs,
46
+ max_new_tokens=50, # adjust as needed
47
+ num_beams=5,
48
+ do_sample=True,
49
+ top_k=50,
50
+ top_p=0.95,
51
+ temperature=0.7
52
+ )
53
  # Decode the generated tokens back into a readable string, skipping special tokens
54
  return grammar_tokenizer.decode(outputs[0], skip_special_tokens=True)
55
 
 
70
  # Generate the output with improved parameters:
71
  # max_new_tokens: Limits the maximum length of the generated response.
72
  # num_beams: Uses beam search for higher quality, less repetitive outputs.
73
+ # temperature: Controls randomness; lower means more deterministic.
 
74
  outputs = flan_model.generate(
75
  **inputs,
76
  max_new_tokens=100, # Limit output length to prevent rambling
77
  num_beams=5, # Use beam search for better quality
 
 
 
78
  temperature=0.7 # Controls randomness; lower means more deterministic
79
  )
80
  # Decode the generated tokens back into a readable string