# Hugging Face Space app by ehsanul007 (commit f74d886: "Update app.py").
import gradio as gr
import torch
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
)
# Hugging Face Hub id of the question-generation checkpoint (tokenizer + model).
QG_PRETRAINED = "ehsanul007/IAmA-question-generator"
# Maximum number of input tokens fed to the question generator.
SEQ_LENGTH = 512

# Use CUDA when PyTorch can see a GPU, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# use_fast=False selects the slow (non-Rust) tokenizer implementation.
qg_tokenizer = AutoTokenizer.from_pretrained(QG_PRETRAINED, use_fast=False)

# Load the seq2seq model, move it to `device`, and put it in eval mode.
# nn.Module.to() and .eval() both return the module itself, so they chain.
qg_model = AutoModelForSeq2SeqLM.from_pretrained(QG_PRETRAINED).to(device).eval()
def evaluate(topic, context):
    """Generate a question for a topic and a user's self-description.

    Args:
        topic: Topic the user wants to be asked a question about.
        context: Free-text description of the user (e.g. "I am ...").

    Returns:
        The generated question as a string, with special tokens removed.
    """
    # Tagged prompt layout; presumably matches the format used when the
    # checkpoint was fine-tuned -- confirm before changing.
    text = f'<topic> {topic} <context> {context}'

    # A single prompt needs no padding: padding to SEQ_LENGTH only made the
    # encoder process (and the decoder cross-attend over) up to SEQ_LENGTH
    # pad tokens per request. Over-long inputs are still truncated.
    encoded_text = qg_tokenizer(
        text,
        max_length=SEQ_LENGTH,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    # Pass the attention mask explicitly instead of letting generate() infer
    # it from pad-token ids. .get() keeps this working if a tokenizer does
    # not emit a mask (generate() then falls back to its own inference).
    output = qg_model.generate(
        input_ids=encoded_text["input_ids"],
        attention_mask=encoded_text.get("attention_mask"),
        max_length=100,
    )

    # Only one prompt was generated, so decode the first (only) sequence.
    question = qg_tokenizer.decode(
        output[0],
        skip_special_tokens=True,
    )
    return question
# Web UI: two text inputs (topic, self-description) -> generated question.
g = gr.Interface(
    fn=evaluate,
    inputs=[
        gr.components.Textbox(
            lines=2, label="Topic", placeholder="Topic Name."
        ),
        gr.components.Textbox(lines=2, label="Your Details", placeholder="I am ..."),
    ],
    outputs=[
        # Built from gr.components like the inputs above. The old gr.inputs
        # namespace is meant for input components, is deprecated in Gradio 3.x
        # and no longer exists in Gradio 4.x.
        gr.components.Textbox(
            lines=5,
            label="What AI wants to ask you",
        )
    ],
    title="IAmA Question Generator",
    description="Write down who you are (add details) and which topic you want to be asked a question on",
)
# Serve queued requests with a single worker, one generation at a time.
# NOTE(review): `concurrency_count` is a Gradio 3.x argument; Gradio 4.x
# rejects it (use `default_concurrency_limit` there) -- confirm pinned version.
g.queue(concurrency_count=1)
g.launch()