Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import T5ForConditionalGeneration, T5Tokenizer | |
| from nltk.tokenize import sent_tokenize | |
| import nltk | |
# Fetch the NLTK "punkt" data package (the resource NLTK's sentence
# tokenizer relies on).
nltk.download('punkt')


@st.cache_resource
def _load_tokenizer_and_model():
    """Load the pre-trained ``t5-base`` tokenizer and model.

    Streamlit re-executes this script from the top on every user
    interaction. Wrapping the load in ``st.cache_resource`` keeps a single
    tokenizer/model pair alive across reruns instead of re-reading the
    weights each time the form is submitted.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        T5Tokenizer.from_pretrained("t5-base"),
        T5ForConditionalGeneration.from_pretrained("t5-base"),
    )


# Load Pre-Trained Model And Tokenizer
tokenizer, model = _load_tokenizer_and_model()
def generate_response(text):
    """Run ``text`` through the module-level T5 model and return its output.

    The prompt is encoded with the module-level ``tokenizer`` (truncated to
    ``max_length=512``), passed to ``model.generate`` with the fixed
    generation settings below, and the first returned sequence is decoded
    with special tokens stripped.

    Args:
        text: The prompt string to feed to the model.

    Returns:
        The decoded text of the first generated sequence.
    """
    encoded_prompt = tokenizer.encode(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    generation_settings = {
        "max_length": 150,
        "min_length": 30,
        "length_penalty": 2.0,
        "num_beams": 4,
        "early_stopping": True,
    }
    generated = model.generate(input_ids=encoded_prompt, **generation_settings)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
def format_messages_for_display(messages):
    """Render a chat history as one newline-separated string.

    Each message is a mapping with ``"role"`` and ``"content"`` keys. A
    message whose role is ``"assistant"`` is prefixed with ``Assistant:``;
    every other role is prefixed with ``User:``.

    Args:
        messages: Iterable of message mappings, in display order.

    Returns:
        One line per message joined with ``"\\n"``; an empty string when
        there are no messages.
    """
    def speaker(entry):
        # Any role other than "assistant" is shown as the user.
        return "Assistant" if entry["role"] == "assistant" else "User"

    return "\n".join(
        f"{speaker(entry)}: {entry['content']}" for entry in messages
    )
def main():
    """Render the chat page: prompt form, model call, and running history.

    The conversation is kept in ``st.session_state['messages']`` as a list
    of ``{"role", "content"}`` dicts. On each non-blank submit, the user's
    prompt and the model's reply are both appended so the displayed history
    shows both sides of the exchange.
    """
    st.title("T5 Chat Interface")

    if 'messages' not in st.session_state:
        st.session_state['messages'] = []

    with st.form(key='input_form'):
        user_input = st.text_area("Enter your prompt:")
        submitted = st.form_submit_button(label="Submit")

    if submitted:
        if user_input.strip():
            history = st.session_state['messages']
            # Record the user's turn first; previously it was built into a
            # throwaway local and never stored, so only assistant replies
            # ever appeared in the history.
            history.append({
                "role": "user",
                "content": user_input
            })
            history.append({
                "role": "assistant",
                "content": generate_response(user_input)
            })
        else:
            # Nothing to send to the model for a blank prompt.
            st.warning("Please enter a prompt before submitting.")

    st.write(format_messages_for_display(st.session_state['messages']))
def save_session():
    """Placeholder for session persistence; currently does nothing.

    Returns:
        None.
    """
    return None
# Build the page only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()