odaly committed on
Commit
c035add
·
verified ·
1 Parent(s): c9bdddd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -120
app.py CHANGED
@@ -1,129 +1,206 @@
1
  import os
2
- import streamlit as st
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
- import torch
5
  import time
6
-
7
- # Hugging Face API Token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  hf_token = os.getenv("HUGGING_FACE_API_TOKEN")
9
-
10
- if hf_token:
11
- st.write(f"Hugging Face API token found: {hf_token[:4]}...") # Displaying only the first 4 characters for security
12
- else:
13
  st.error("Hugging Face API token not found. Please set the HUGGING_FACE_API_TOKEN environment variable.")
14
  st.stop()
15
 
16
- # Model ID (use a valid model from Hugging Face)
17
- model_id = "gpt2" # Replace with a valid model
18
-
19
- # Initialize the model and tokenizer
20
- tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_token)
21
- model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=hf_token)
22
-
23
- # Set pad_token_id to eos_token_id to avoid the warning
24
- if tokenizer.pad_token is None:
25
- tokenizer.pad_token = tokenizer.eos_token
26
-
27
- # Alternatively, add a new padding token if it's not defined
28
- # if tokenizer.pad_token is None:
29
- # tokenizer.add_special_tokens({'pad_token': '[PAD]'})
30
- # model.resize_token_embeddings(len(tokenizer))
31
-
32
- def generate_response(prompt):
33
- # Tokenize the prompt with attention mask
34
- inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # Generate text with the attention mask
37
- output = model.generate(
38
- inputs['input_ids'],
39
- attention_mask=inputs['attention_mask'], # Pass attention mask to prevent the warning
40
- max_length=150,
41
- num_return_sequences=1,
42
- do_sample=True,
43
- top_k=50,
44
- top_p=0.95
45
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- # Decode the generated output
48
- response = tokenizer.decode(output[0], skip_special_tokens=True)
49
- return response
50
-
51
- def save_chat():
52
- chat_dir = './Intermediate-Chats'
53
- if not os.path.exists(chat_dir):
54
- os.makedirs(chat_dir)
55
- if st.session_state['messages']:
56
- filename = f'{chat_dir}/chat_{int(time.time())}.txt'
57
- with open(filename, 'w') as f:
58
- for message in st.session_state['messages']:
59
- f.write(f"{message['role']}: {message['content']}\n")
60
- st.session_state['messages'].clear()
61
- st.success("Chat saved successfully.")
62
  else:
63
- st.warning("No chat messages to save.")
64
-
65
- def load_saved_chats():
66
- chat_dir = './Intermediate-Chats'
67
- if os.path.exists(chat_dir):
68
- files = os.listdir(chat_dir)
69
- files.sort(key=lambda x: os.path.getmtime(os.path.join(chat_dir, x)), reverse=True)
70
- for file_name in files:
71
- display_name = file_name[:-4] if file_name.endswith('.txt') else file_name
72
- if st.sidebar.button(display_name):
73
- load_chat(os.path.join(chat_dir, file_name))
74
-
75
- def load_chat(file_path):
76
- st.session_state['messages'].clear()
77
- with open(file_path, 'r') as file:
78
- for line in file:
79
- if ': ' in line:
80
- role, content = line.strip().split(': ', 1)
81
- st.session_state['messages'].append({'role': role, 'content': content})
82
-
83
- def response_generator(content):
84
- current_output = ""
85
- for word in content.split():
86
- current_output += word + " "
87
- yield current_output.strip()
88
- time.sleep(0.2)
89
-
90
- def main():
91
- st.title("LLaMA Chat Interface")
92
-
93
- if 'messages' not in st.session_state:
94
- st.session_state['messages'] = []
95
-
96
- # Display chat messages
97
- for msg in st.session_state.messages:
98
- role = msg['role']
99
- with st.chat_message(role):
100
- st.write(msg['content'])
101
-
102
- # Accept user input
103
- user_input = st.chat_input("Enter your prompt:")
104
- if user_input:
105
- st.session_state.messages.append({"role": "user", "content": user_input})
106
- response = generate_response(user_input)
107
- st.session_state.messages.append({"role": "assistant", "content": response})
108
-
109
- # Streaming response in the chat interface
110
- with st.chat_message("assistant"):
111
- placeholder = st.empty()
112
- full_response = ""
113
- for word in response_generator(response):
114
- full_response += word
115
- placeholder.write(full_response)
116
-
117
- # Sidebar functionality
118
- if st.sidebar.button("Save Chat"):
119
- save_chat()
120
-
121
- if st.sidebar.button("New Chat"):
122
- st.session_state['messages'].clear()
123
-
124
- if st.sidebar.checkbox("Show/hide chat history"):
125
- st.sidebar.title("Previous Chats")
126
- load_saved_chats()
127
-
128
- if __name__ == "__main__":
129
- main()
 
 
 
1
  import os
 
 
 
2
  import time
3
+ import re
4
+ import requests
5
+ import json
6
+ from bs4 import BeautifulSoup
7
+ import streamlit as st
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, LlamaConfig
9
+ from streamlit_chat import message
10
+
11
# Set page title and icon shown in the browser tab.
st.set_page_config(page_title="LLaMA Chatbot", page_icon=":robot_face:")

# Custom CSS for styling: green rounded buttons and a light-grey text area.
# unsafe_allow_html is required for Streamlit to render raw <style> markup.
st.markdown(
    """
    <style>
    .stButton>button {
        background-color: #4CAF50;
        color: white;
        border-radius: 12px;
        padding: 10px 24px;
    }
    .stTextArea textarea {
        background-color: #f5f5f5;
    }
    .stDownloadButton>button {
        background-color: #4CAF50;
        color: white;
    }
    </style>
    """, unsafe_allow_html=True
)
34
+
35
# Load the Hugging Face API token and abort immediately if it is missing:
# the model downloads below require authenticated access.
if not (hf_token := os.getenv("HUGGING_FACE_API_TOKEN")):
    st.error("Hugging Face API token not found. Please set the HUGGING_FACE_API_TOKEN environment variable.")
    st.stop()
40
 
41
# Initialize session state from a table of defaults. Only keys that are
# not present yet are created, so existing values survive Streamlit reruns.
_session_defaults = {
    'generated': [],
    'past': [],
    'messages': [{"role": "system", "content": "You are a helpful assistant."}],
    'model_name': [],
    'total_tokens': [],
    'total_cost': 0.0,
    'chat_data': [],  # For storing the chat logs
}
for _key, _default in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
56
+
57
# Sidebar - Model Selection, Style Parameters, and Cost Display
st.sidebar.title("Model Selection")
# NOTE(review): "gpt-neo-125M" and "LLaMA" do not look like valid Hugging
# Face repo ids as written (gpt-neo lives under "EleutherAI/...") — confirm
# these resolve in from_pretrained before shipping.
model_name = st.sidebar.selectbox("Choose a model:", ["gpt2", "gpt-neo-125M", "distilgpt2", "LLaMA"])

# Parameters to adjust the response style and creativity; all of these are
# read as module-level globals by generate_response().
st.sidebar.title("Response Style Controls")
temperature = st.sidebar.slider("Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.7, step=0.1)
top_p = st.sidebar.slider("Nucleus Sampling (Top-p)", min_value=0.5, max_value=1.0, value=0.9, step=0.05)
top_k = st.sidebar.slider("Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1)
repetition_penalty = st.sidebar.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)
# max_length here is the TOTAL generated sequence length (prompt included).
max_length = st.sidebar.slider("Max Length", min_value=50, max_value=4859, value=500, step=10)
68
+
69
# Function to load the model and tokenizer.
@st.cache_resource  # cached so weights are loaded once per process, not on every rerun
def load_model_and_tokenizer(model_name):
    """Load the tokenizer and causal-LM weights for *model_name*.

    Names containing "LLaMA" go through the Llama-specific classes;
    everything else uses the Auto* factories. ``hf_token`` authenticates
    the download of gated repositories.

    Returns a ``(tokenizer, model)`` tuple.
    """
    if "LLaMA" in model_name:
        # NOTE(review): the sidebar passes the literal string "LLaMA",
        # which is not a valid repo id — from_pretrained will likely fail
        # on this branch; map it to a real repo id before release.
        tokenizer = LlamaTokenizer.from_pretrained(model_name, token=hf_token)
        config = LlamaConfig.from_pretrained(model_name, token=hf_token)
        model = LlamaForCausalLM.from_pretrained(model_name, config=config, token=hf_token)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)

    return tokenizer, model

# Module-level tokenizer/model used by generate_response() below.
tokenizer, model = load_model_and_tokenizer(model_name)
+
84
def reset_session():
    """Reset every chat-related session-state key to a fresh conversation."""
    fresh_state = {
        'generated': [],
        'past': [],
        'messages': [{"role": "system", "content": "You are a helpful assistant."}],
        'model_name': [],
        'total_tokens': [],
        'total_cost': 0.0,
        'chat_data': [],  # chat logs kept for later fine-tuning
    }
    for state_key, state_value in fresh_state.items():
        st.session_state[state_key] = state_value

# Sidebar button that wipes the current conversation.
if st.sidebar.button("Reset Chat"):
    reset_session()
98
+
99
# Function to fetch and parse a webpage for specific tags.
def fetch_website_content(url):
    """Fetch *url* and extract its headings, paragraphs and articles.

    Returns a dict with keys ``"headings"``, ``"paragraphs"`` and
    ``"articles"`` (each a list of strings) on HTTP 200, or a dict with a
    single ``"error"`` key describing the failure. Never raises: any
    exception is folded into the error dict so the chat UI keeps running.
    """
    try:
        # A timeout stops the Streamlit worker from hanging forever on an
        # unresponsive host (the original request had no timeout at all).
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            soup = BeautifulSoup(response.text, 'html.parser')
            headings = [h.get_text() for h in soup.find_all(['h1', 'h2', 'h3'])]
            paragraphs = [p.get_text() for p in soup.find_all('p')]
            articles = [article.get_text() for article in soup.find_all('article')]

            content = {
                "headings": headings,
                "paragraphs": paragraphs,
                "articles": articles
            }
            return content
        else:
            return {"error": f"Failed to retrieve content, status code: {response.status_code}"}
    except Exception as e:
        return {"error": f"An error occurred: {str(e)}"}
119
+
120
# Function to check if the input contains a URL.
def extract_url_from_text(text):
    """Return every http/https URL found in *text* (empty list if none)."""
    url_pattern = re.compile(
        r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    )
    return url_pattern.findall(text)
125
+
126
# Function to generate a response using the model with adjustable parameters.
def generate_response(prompt):
    """Produce a reply to *prompt*.

    If the prompt contains a URL, the first URL is crawled via
    ``fetch_website_content`` and the extracted content (or its error
    message) is returned directly. Otherwise the globally loaded model
    samples a completion using the sidebar style parameters
    (``temperature``, ``top_p``, ``top_k``, ``repetition_penalty``,
    ``max_length``). Returns a plain string either way.
    """
    urls = extract_url_from_text(prompt)

    if urls:
        # If a URL is detected, crawl the webpage and extract content
        url_content = fetch_website_content(urls[0])  # Crawl only the first URL for simplicity
        if 'error' in url_content:
            return url_content['error']
        else:
            return f"Headings: {url_content['headings']}\n\nParagraphs: {url_content['paragraphs']}\n\nArticles: {url_content['articles']}"
    else:
        # If no URL, proceed with generating a response from the model.
        # truncation=True guards against prompts longer than the model's
        # context window, which would otherwise make generate() fail.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

        # Pass attention_mask and set pad_token_id explicitly to silence
        # the transformers warnings for GPT-2-style models (no pad token).
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=max_length,  # total length, prompt tokens included
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id  # Set pad_token_id
        )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
156
+
157
# Function to save chat logs for later fine-tuning.
def save_chat_data(chat_data):
    """Overwrite ``chat_data.json`` with *chat_data* serialized as JSON.

    Writes UTF-8 explicitly (the original relied on the platform default
    encoding) and keeps non-ASCII chat text readable instead of letting
    json escape it to \\uXXXX sequences.
    """
    with open('chat_data.json', 'w', encoding='utf-8') as f:
        json.dump(chat_data, f, indent=4, ensure_ascii=False)
161
+
162
# Containers for chat history and user input.
response_container = st.container()
container = st.container()

with container:
    # Input form: submitting triggers exactly one Streamlit rerun.
    with st.form(key='user_input_form'):
        user_input = st.text_area("You:", key='user_input', height=100)
        submit_button = st.form_submit_button("Send")

    if submit_button and user_input:
        # Time the model call so the latency can be shown below.
        start_time = time.time()
        output = generate_response(user_input)
        end_time = time.time()
        inference_time = end_time - start_time

        # Append user input and model output to session state
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(output)
        st.session_state['model_name'].append(model_name)

        # Log chat data for future training
        st.session_state['chat_data'].append({
            "user_input": user_input,
            "model_response": output
        })

        # Save chat data to a file (this could be used later for training)
        save_chat_data(st.session_state['chat_data'])

        # Calculate tokens and cost
        num_tokens = len(tokenizer.encode(user_input)) + len(tokenizer.encode(output))
        st.session_state['total_tokens'].append(num_tokens)
        cost_per_1000_tokens = 0.0001
        cost = cost_per_1000_tokens * (num_tokens / 1000)
        st.session_state['total_cost'] += cost

        # Display chat history
        # NOTE(review): `inference_time` and `cost` exist only in this submit
        # branch, yet they are printed for EVERY historical message — each
        # old exchange shows the latest latency/cost. Store them per message
        # (like total_tokens) to fix. Also confirm whether this display block
        # is intended to run only after a submit; if the history should render
        # on every rerun, it must move out of this branch and stop referencing
        # the branch-local variables.
        with response_container:
            for i in range(len(st.session_state['generated'])):
                message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
                message(st.session_state['generated'][i], key=str(i))
                st.write(f"Model: {st.session_state['model_name'][i]}")
                st.write(f"Tokens: {st.session_state['total_tokens'][i]}")
                st.write(f"Inference Time: {inference_time:.4f} seconds")
                st.write(f"Cost: ${cost:.5f}")