# File-viewer header captured with this file:
# author: jeevana · commit 4ad1def ("Update app.py", verified) · 839 bytes
import gradio as gr
import torch
from transformers import pipeline
# Load the fine-tuned subject-line model once, at import time.
# The task is given explicitly: predict() consumes "generated_text" output, i.e.
# text generation. Without a task, transformers would infer it from the model's
# Hub metadata instead.
# NOTE: this rebinds the imported `pipeline` factory name to the pipeline
# instance; the name is kept because predict() calls `pipeline(...)`.
pipeline = pipeline(
    task="text-generation",
    model="jeevana/EmailSubjectLineGeneration",
    max_new_tokens=20,  # upper bound on newly generated tokens (the subject)
)
def predict(input):
    """Generate a subject line for an email body.

    The email is suffixed with the ``"\\n@subject\\n"`` marker and passed to
    the module-level text-generation ``pipeline``. The text the model generated
    after that prompt is returned as the subject.

    Args:
        input: Email body text from the Gradio textbox. The name is kept for
            interface compatibility even though it shadows the builtin.

    Returns:
        The generated continuation following the prompt, or ``""`` if the
        pipeline returned no text.
    """
    marker = "\n@subject\n"
    prompt = input + marker
    outputs = pipeline(prompt)
    generated = outputs[0].get("generated_text") or ""

    # The returned text normally starts with the prompt followed by the
    # continuation. Detokenization is not guaranteed to reproduce the prompt
    # byte-for-byte, so a blind len(prompt) slice can cut into (or past) the
    # subject. Prefer an exact prefix match.
    if generated.startswith(prompt):
        return generated[len(prompt):]

    # Fallback: take everything after the marker that terminates the prompt.
    # The email itself may contain the marker, so skip as many occurrences as
    # the prompt holds.
    n_markers = prompt.count(marker)
    pieces = generated.split(marker, n_markers)
    if len(pieces) > n_markers:
        return pieces[-1]

    # Marker not found: keep the original length-based behaviour.
    return generated[len(prompt):]
# Single-textbox UI: email body in, generated subject line out.
app = gr.Interface(
    fn=predict,
    inputs=[gr.Textbox(label="Email", lines=15)],
    outputs=[gr.Textbox(label="Subject", lines=15)],
    title="EmailSubjectLineGeneration",
    description="EmailSubjectLineGeneration with GPT2",
)

# Launch only when executed as a script (e.g. `python app.py`), so importing
# this module does not start a blocking server as a side effect.
# share=True requests a public Gradio share link; debug=True keeps the call
# blocking and surfaces errors in the console (see Gradio launch() docs).
if __name__ == "__main__":
    app.launch(share=True, debug=True)