odaly commited on
Commit
6cd3a3a
·
verified ·
1 Parent(s): 3e22531

create mainpy

Browse files
Files changed (1) hide show
  1. main.py +121 -0
main.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
3
+ from nltk.tokenize import sent_tokenize
4
+ import nltk
5
+ import json
6
+ import os
7
+ import time
8
+
9
+ nltk.download('punkt')
10
+
11
+ # Load Pre-Trained Model And Tokenizer
12
+ tokenizer = T5Tokenizer.from_pretrained("t5-base")
13
+ model = T5ForConditionalGeneration.from_pretrained("t5-base")
14
+
15
+ def response_generator(msg_content):
16
+ lines = msg_content.split('\n')
17
+ for line in lines:
18
+ words = line.split()
19
+ for word in words:
20
+ yield word + " "
21
+ time.sleep(0.1)
22
+ yield "\n"
23
+
24
+ def show_msgs():
25
+ for msg in st.session_state.messages:
26
+ role = msg["role"]
27
+ with st.chat_message(role):
28
+ st.write(msg["content"])
29
+
30
+ def generate_response(text):
31
+ input_ids = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
32
+ response_ids = model.generate(input_ids=input_ids, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
33
+ output = tokenizer.decode(response_ids[0], skip_special_tokens=True)
34
+ return output
35
+
36
+ def format_messages_for_summary(messages):
37
+ return '\n'.join(f"{msg['role']}: {msg['content']}" for msg in messages)
38
+
39
+ def save_chat():
40
+ if not os.path.exists('./Intermediate-Chats'):
41
+ os.makedirs('./Intermediate-Chats')
42
+ if st.session_state['messages']:
43
+ formatted_messages = format_messages_for_summary(st.session_state['messages'])
44
+ filename = f'./Intermediate-Chats/chat_{int(time.time())}.txt'
45
+ with open(filename, 'w') as f:
46
+ for message in st.session_state['messages']:
47
+ encoded_content = message['content'].replace('\n', '\\n')
48
+ f.write(f"{message['role']}: {encoded_content}\n")
49
+ st.session_state['messages'].clear()
50
+ else:
51
+ st.warning("No chat messages to save.")
52
+
53
+ def load_saved_chats():
54
+ chat_dir = './Intermediate-Chats'
55
+ if os.path.exists(chat_dir):
56
+ files = os.listdir(chat_dir)
57
+ files.sort(key=lambda x: os.path.getmtime(os.path.join(chat_dir, x)), reverse=True)
58
+ for file_name in files:
59
+ display_name = file_name[:-4] if file_name.endswith('.txt') else file_name
60
+ if st.sidebar.button(display_name):
61
+ st.session_state['show_chats'] = False
62
+ st.session_state['is_loaded'] = True
63
+ load_chat(os.path.join(chat_dir, file_name))
64
+
65
+ def load_chat(file_path):
66
+ st.session_state['messages'].clear()
67
+ with open(file_path, 'r') as file:
68
+ for line in file.readlines():
69
+ role, content = line.strip().split(': ', 1)
70
+ decoded_content = content.replace('\\n', '\n')
71
+ st.session_state['messages'].append({'role': role, 'content': decoded_content})
72
+
73
+ def main():
74
+ st.title("T5 Chat Interface")
75
+
76
+ if 'messages' not in st.session_state:
77
+ st.session_state['messages'] = []
78
+ if 'show_chats' not in st.session_state:
79
+ st.session_state['show_chats'] = False
80
+
81
+ # File uploader
82
+ uploaded_files = st.file_uploader("Choose multiple files", type=["txt", "docx", "py", "java", "class", "php", "js", "css"], accept_multiple_files=True)
83
+ if uploaded_files:
84
+ for uploaded_file in uploaded_files:
85
+ file_name = uploaded_file.name
86
+ file_content = uploaded_file.read().decode("utf-8")
87
+ with st.expander(file_name):
88
+ st.write(file_content)
89
+
90
+ show_msgs()
91
+
92
+ user_input = st.chat_input("Enter your prompt:", key="1")
93
+ if user_input:
94
+ with st.chat_message("user"):
95
+ st.write(user_input)
96
+ st.session_state.messages.append({"role": "user", "content": user_input})
97
+ response = generate_response(user_input)
98
+ st.session_state.messages.append({"role": "assistant", "content": response})
99
+ with st.chat_message("assistant"):
100
+ st.write_stream(response_generator(response))
101
+
102
+ chatlog = format_messages_for_summary(st.session_state['messages'])
103
+ st.sidebar.download_button(
104
+ label="Download Chat Log",
105
+ data=chatlog,
106
+ file_name="chat_log.txt",
107
+ mime="text/plain"
108
+ )
109
+
110
+ if st.sidebar.button("Save Chat"):
111
+ save_chat()
112
+
113
+ if st.sidebar.button("New Chat"):
114
+ st.session_state['messages'].clear()
115
+
116
+ if st.sidebar.checkbox("Show/hide chat history", value=st.session_state['show_chats']):
117
+ st.sidebar.title("Previous Chats")
118
+ load_saved_chats()
119
+
120
+ if __name__ == "__main__":
121
+ main()