from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from gradio import Interface
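# NOTE: facebook/bart-base is a general-purpose seq2seq checkpoint, not a model
# fine-tuned for question generation or answering; it serves as a stand-in here,
# and a task-specific fine-tune would likely produce far better output.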
model_name = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
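# from_pretrained() already returns the model in eval mode; stating it
# explicitly documents that this script only runs inference.
model.eval()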


def generate_questions(email):
    """Generates questions based on the input email."""
    # Truncate at tokenization time: input_ids is shaped (batch, seq_len), so
    # the tokenizer's truncation keeps both it and the attention mask within
    # the 512-token budget.
    inputs = tokenizer(email, return_tensors="pt", truncation=True, max_length=512)
    generation = model.generate(
        **inputs,
        max_length=256,
    )
    return tokenizer.decode(generation[0], skip_special_tokens=True)
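
# Example call (the email text is a made-up placeholder):
#   questions = generate_questions("Hi team, can we move Friday's review to 3pm?")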


def generate_answers(questions):
    """Generates possible answers to the input questions."""
    # model.generate() has no `prompt` keyword argument, so the instruction is
    # prepended to the encoder input instead.
    prompt = "Here are some possible answers to the questions:\n"
    inputs = tokenizer(
        prompt + "\n".join(questions),
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    generation = model.generate(
        **inputs,
        max_length=512,
        num_beams=3,
        early_stopping=True,
    )
    answers = tokenizer.decode(generation[0], skip_special_tokens=True).split("\n")
    # Skip the first line, which tends to echo the prompt back.
    return list(zip(questions, answers[1:]))
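
# Example call (hypothetical questions):
#   qa_pairs = generate_answers(["When is the review?", "Who is attending?"])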


def gradio_app(email):
    """Gradio interface function."""
    questions = generate_questions(email)
    answers = generate_answers(questions.split("\n"))
    # Join the answers into a single string so each return value matches a
    # "text" output component.
    return questions, "\n".join(answer for _, answer in answers)


interface = Interface(
    fn=gradio_app,
    inputs="textbox",
    outputs=["text", "text"],
    title="AI Email Assistant",
    description="Enter a long email and get questions and possible answers generated by an AI model.",
)
interface.launch()
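
# Passing share=True to launch() would additionally expose a temporary public URL.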