fuzzylab / app1.py
odaly's picture
rename app.py to app1.py
e26bdb4 verified
raw
history blame
1.75 kB
import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer
# NOTE(review): sent_tokenize is imported but not used anywhere in this file —
# confirm whether it is needed before removing.
from nltk.tokenize import sent_tokenize
import nltk
# Import-time side effect: downloads the 'punkt' tokenizer data on every start.
nltk.download('punkt')
# Load Pre-Trained Model And Tokenizer.
# The t5-base weights are fetched from the Hugging Face hub (and cached) the
# first time this module is imported — this can take a while on a cold start.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")
def generate_response(text):
    """Run the module-level T5 model on *text* and return the decoded string.

    The prompt is truncated to 512 tokens; generation uses beam search
    (4 beams) bounded to 30-150 output tokens.
    """
    encoded_prompt = tokenizer.encode(
        text, return_tensors="pt", max_length=512, truncation=True
    )
    generated = model.generate(
        input_ids=encoded_prompt,
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
def format_messages_for_display(messages):
    """Render a chat transcript as newline-separated "Role: content" lines.

    Messages with role "assistant" are prefixed "Assistant: "; every other
    role is shown as "User: ".
    """
    def _render(msg):
        speaker = "Assistant" if msg["role"] == "assistant" else "User"
        return f"{speaker}: {msg['content']}"

    return "\n".join(_render(msg) for msg in messages)
def main():
    """Render the Streamlit chat UI and drive one prompt/response turn.

    Keeps the full transcript in ``st.session_state['messages']`` so it
    survives Streamlit reruns.
    """
    st.title("T5 Chat Interface")

    # Initialise persistent chat history on the first run of the session.
    if 'messages' not in st.session_state:
        st.session_state['messages'] = []

    with st.form(key='input_form'):
        user_input = st.text_area("Enter your prompt:")
        submitted = st.form_submit_button(label="Submit")

    if submitted:
        # Bug fix: the original built the user message into a throwaway
        # local list and never stored it, so the transcript showed only
        # assistant turns. Record the user message in session history.
        st.session_state['messages'].append({
            "role": "user",
            "content": user_input
        })
        response = generate_response(user_input)
        st.session_state['messages'].append({
            "role": "assistant",
            "content": response
        })

    st.write(format_messages_for_display(st.session_state['messages']))
def save_session():
    """Placeholder for persisting the chat session; currently a no-op."""
# Standard script entry point: run the Streamlit app only when executed
# directly, not when imported.
if __name__ == '__main__':
    main()