Spaces:
Runtime error
Runtime error
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline | |
| import torch | |
# All models in this module are placed on the CPU.
device = torch.device("cpu")

# Hub checkpoint ids, named once so each tokenizer/model pair is loaded
# from the same repository.
_GRAMMAR_CHECKPOINT = "vennify/t5-base-grammar-correction"
_FLAN_CHECKPOINT = "google/flan-t5-small"
_TRANSLATION_CHECKPOINT = "Helsinki-NLP/opus-mt-en-ROMANCE"
_TONE_CHECKPOINT = "j-hartmann/emotion-english-distilroberta-base"

# --- Grammar model ---
# Seq2seq checkpoint used by run_grammar_correction: text in, corrected
# text out.
grammar_tokenizer = AutoTokenizer.from_pretrained(_GRAMMAR_CHECKPOINT)
grammar_model = AutoModelForSeq2SeqLM.from_pretrained(_GRAMMAR_CHECKPOINT)
grammar_model = grammar_model.to(device)

# --- FLAN-T5 for all prompts ---
# General prompt-driven generation (paraphrasing, summarizing, tone
# suggestions) goes through this small FLAN-T5 checkpoint.
flan_tokenizer = AutoTokenizer.from_pretrained(_FLAN_CHECKPOINT)
flan_model = AutoModelForSeq2SeqLM.from_pretrained(_FLAN_CHECKPOINT)
flan_model = flan_model.to(device)

# --- Translation model ---
# English -> Romance-language translation checkpoint.
trans_tokenizer = AutoTokenizer.from_pretrained(_TRANSLATION_CHECKPOINT)
trans_model = AutoModelForSeq2SeqLM.from_pretrained(_TRANSLATION_CHECKPOINT)
trans_model = trans_model.to(device)

# --- Tone classification model ---
# Emotion classifier used by classify_tone; top_k=1 asks the pipeline for
# only the highest-scoring label.
tone_classifier = pipeline(
    "sentiment-analysis",
    model=_TONE_CHECKPOINT,
    top_k=1,
)
def run_grammar_correction(text: str, max_new_tokens: int = 256) -> str:
    """
    Corrects the grammar of the input text using the module-level grammar model.

    The text is encoded with ``grammar_tokenizer``, passed through
    ``grammar_model.generate`` and the first generated sequence is decoded.

    Args:
        text (str): The input text to be grammatically corrected.
        max_new_tokens (int): Upper bound on the number of tokens the model
            may generate for the corrected text.

    Returns:
        str: The corrected text. Empty or whitespace-only input is returned
        unchanged without calling the model.
    """
    # Nothing to correct: skip the model instead of letting it invent output.
    if not text.strip():
        return text

    # The vennify/t5-base-grammar-correction model card prompts the model
    # with the "grammar: " task prefix.
    inputs = grammar_tokenizer(f"grammar: {text}", return_tensors="pt").to(device)

    # An explicit budget is required: without one, generate() falls back to
    # the library's short default length and cuts longer corrections off.
    outputs = grammar_model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Decode the first (only) sequence, dropping pad/eos markers.
    return grammar_tokenizer.decode(outputs[0], skip_special_tokens=True)
def run_flan_prompt(prompt: str, max_new_tokens: int = 256) -> str:
    """
    Runs a given prompt through the module-level FLAN-T5 model.

    The prompt is encoded with ``flan_tokenizer``, passed through
    ``flan_model.generate`` and the first generated sequence is decoded.

    Args:
        prompt (str): The prompt string to be processed by FLAN-T5.
        max_new_tokens (int): Upper bound on the number of tokens the model
            may generate in its response.

    Returns:
        str: The generated text response, or an empty string when the prompt
        is empty or whitespace-only (the model is not called in that case).
    """
    # An empty prompt carries no instruction, so there is nothing to generate.
    if not prompt.strip():
        return ""

    inputs = flan_tokenizer(prompt, return_tensors="pt").to(device)

    # An explicit budget is required: without one, generate() falls back to
    # the library's short default length, which truncates summaries and
    # paraphrases after a handful of words.
    outputs = flan_model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Decode the first (only) sequence, dropping pad/eos markers.
    return flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
def run_translation(text: str, target_lang: str) -> str:
    """
    Translates the input text to the target language using the module-level
    translation model.

    The target language is selected by prefixing the text with a
    ``>>{target_lang}<<`` marker before encoding it with ``trans_tokenizer``.

    Args:
        text (str): The input text to be translated.
        target_lang (str): The target language code (e.g., "fr" for French).

    Returns:
        str: The translated text. Empty or whitespace-only input is returned
        unchanged without calling the model.
    """
    # Nothing to translate: skip the model instead of letting it invent output.
    if not text.strip():
        return text

    # truncation=True caps the encoded input at the tokenizer's maximum
    # length, so an over-long text is shortened rather than overflowing the
    # model's position range (NOTE(review): text beyond that limit is
    # dropped; split long documents upstream if that matters).
    inputs = trans_tokenizer(
        f">>{target_lang}<< {text}",
        return_tensors="pt",
        truncation=True,
    ).to(device)

    outputs = trans_model.generate(**inputs)

    # Decode the first (only) sequence, dropping pad/eos markers.
    return trans_tokenizer.decode(outputs[0], skip_special_tokens=True)
def classify_tone(text: str) -> str:
    """
    Classifies the emotional tone of the input text using the module-level
    ``tone_classifier`` pipeline.

    Args:
        text (str): The input text for tone classification.

    Returns:
        str: The label of the first prediction returned by the classifier
        (e.g., 'neutral', 'joy', 'sadness').

    Raises:
        IndexError: If the classifier returns an empty result list.
    """
    prediction = tone_classifier(text)

    # Depending on the transformers version and how top_k was supplied, the
    # pipeline wraps its prediction dict in one or two list levels. Peel off
    # however many there are instead of hard-coding the depth.
    while isinstance(prediction, (list, tuple)):
        prediction = prediction[0]

    return prediction["label"]