odaly committed on
Commit
c9bdddd
·
verified ·
1 Parent(s): 12195e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -12
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import streamlit as st
3
- import transformers
4
  import torch
5
  import time
6
 
@@ -13,22 +13,41 @@ else:
13
  st.error("Hugging Face API token not found. Please set the HUGGING_FACE_API_TOKEN environment variable.")
14
  st.stop()
15
 
 
 
 
16
  # Initialize the model and tokenizer
17
- model_id = "gpt2" # Example model ID
18
- tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, token=hf_token)
19
- model = transformers.AutoModelForCausalLM.from_pretrained(model_id, token=hf_token)
 
 
 
 
 
 
 
 
20
 
21
  def generate_response(prompt):
22
- input_ids = tokenizer.encode(prompt, return_tensors='pt')
23
- output = model.generate(input_ids, max_length=150, num_return_sequences=1, do_sample=True, top_k=50, top_p=0.95)
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  response = tokenizer.decode(output[0], skip_special_tokens=True)
25
  return response
26
 
27
- def response_generator(content):
28
- for word in content.split():
29
- yield word + " "
30
- time.sleep(0.1) # Small delay for streaming effect
31
-
32
  def save_chat():
33
  chat_dir = './Intermediate-Chats'
34
  if not os.path.exists(chat_dir):
@@ -61,6 +80,13 @@ def load_chat(file_path):
61
  role, content = line.strip().split(': ', 1)
62
  st.session_state['messages'].append({'role': role, 'content': content})
63
 
 
 
 
 
 
 
 
64
  def main():
65
  st.title("LLaMA Chat Interface")
66
 
@@ -82,10 +108,11 @@ def main():
82
 
83
  # Streaming response in the chat interface
84
  with st.chat_message("assistant"):
 
85
  full_response = ""
86
  for word in response_generator(response):
87
  full_response += word
88
- st.write(full_response) # Re-write the entire content for a streaming effect
89
 
90
  # Sidebar functionality
91
  if st.sidebar.button("Save Chat"):
 
1
  import os
2
  import streamlit as st
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
  import time
6
 
 
13
  st.error("Hugging Face API token not found. Please set the HUGGING_FACE_API_TOKEN environment variable.")
14
  st.stop()
15
 
16
# Model ID (use a valid model hosted on the Hugging Face Hub)
model_id = "gpt2"  # Replace with a valid model

# Initialize the tokenizer and model.
# NOTE: `token=` is the current authentication parameter for
# `from_pretrained`; `use_auth_token=` is deprecated in recent
# transformers releases and emits a FutureWarning.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token)

# GPT-2 defines no pad token; reuse the EOS token so that padded
# tokenization and generate() do not warn about a missing pad_token_id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
31
 
32
def generate_response(prompt):
    """Generate a sampled continuation for *prompt* and return only the new text.

    The prompt is tokenized with an attention mask (needed for reliable
    sampled generation once ``pad_token == eos_token``), and the prompt
    tokens are stripped from the decoded output — ``model.generate``
    returns prompt + continuation, so decoding ``output[0]`` directly
    would echo the user's message back into the chat.
    """
    # Tokenize the prompt, producing both input_ids and attention_mask.
    inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True)

    output = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],  # silence the attention-mask warning
        max_length=150,
        num_return_sequences=1,
        do_sample=True,
        top_k=50,
        top_p=0.95,
    )

    # Drop the prompt tokens so only the model's continuation is returned.
    new_tokens = output[0][inputs['input_ids'].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
50
 
 
 
 
 
 
51
  def save_chat():
52
  chat_dir = './Intermediate-Chats'
53
  if not os.path.exists(chat_dir):
 
80
  role, content = line.strip().split(': ', 1)
81
  st.session_state['messages'].append({'role': role, 'content': content})
82
 
83
def response_generator(content):
    """Yield *content* one word at a time (each with a trailing space)
    for a streaming typewriter effect.

    The caller accumulates the yielded pieces (``full_response += word``)
    before writing, so each yield must be an increment, not the cumulative
    text — yielding cumulative strings made the displayed response repeat
    earlier words on every step.
    """
    for word in content.split():
        yield word + " "
        time.sleep(0.2)  # small delay so the stream is visible
89
+
90
  def main():
91
  st.title("LLaMA Chat Interface")
92
 
 
108
 
109
  # Streaming response in the chat interface
110
  with st.chat_message("assistant"):
111
+ placeholder = st.empty()
112
  full_response = ""
113
  for word in response_generator(response):
114
  full_response += word
115
+ placeholder.write(full_response)
116
 
117
  # Sidebar functionality
118
  if st.sidebar.button("Save Chat"):