Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
import streamlit as st
|
| 3 |
-
|
| 4 |
import torch
|
| 5 |
import time
|
| 6 |
|
|
@@ -13,22 +13,41 @@ else:
|
|
| 13 |
st.error("Hugging Face API token not found. Please set the HUGGING_FACE_API_TOKEN environment variable.")
|
| 14 |
st.stop()
|
| 15 |
|
|
|
|
|
|
|
|
|
|
| 16 |
# Initialize the model and tokenizer
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
def generate_response(prompt):
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
response = tokenizer.decode(output[0], skip_special_tokens=True)
|
| 25 |
return response
|
| 26 |
|
| 27 |
-
def response_generator(content):
|
| 28 |
-
for word in content.split():
|
| 29 |
-
yield word + " "
|
| 30 |
-
time.sleep(0.1) # Small delay for streaming effect
|
| 31 |
-
|
| 32 |
def save_chat():
|
| 33 |
chat_dir = './Intermediate-Chats'
|
| 34 |
if not os.path.exists(chat_dir):
|
|
@@ -61,6 +80,13 @@ def load_chat(file_path):
|
|
| 61 |
role, content = line.strip().split(': ', 1)
|
| 62 |
st.session_state['messages'].append({'role': role, 'content': content})
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
def main():
|
| 65 |
st.title("LLaMA Chat Interface")
|
| 66 |
|
|
@@ -82,10 +108,11 @@ def main():
|
|
| 82 |
|
| 83 |
# Streaming response in the chat interface
|
| 84 |
with st.chat_message("assistant"):
|
|
|
|
| 85 |
full_response = ""
|
| 86 |
for word in response_generator(response):
|
| 87 |
full_response += word
|
| 88 |
-
|
| 89 |
|
| 90 |
# Sidebar functionality
|
| 91 |
if st.sidebar.button("Save Chat"):
|
|
|
|
| 1 |
import os
|
| 2 |
import streamlit as st
|
| 3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 4 |
import torch
|
| 5 |
import time
|
| 6 |
|
|
|
|
| 13 |
st.error("Hugging Face API token not found. Please set the HUGGING_FACE_API_TOKEN environment variable.")
|
| 14 |
st.stop()
|
| 15 |
|
| 16 |
+
# Model ID (use a valid model from Hugging Face)
|
| 17 |
+
model_id = "gpt2" # Replace with a valid model
|
| 18 |
+
|
| 19 |
# Initialize the model and tokenizer
|
| 20 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_token)
|
| 21 |
+
model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=hf_token)
|
| 22 |
+
|
| 23 |
+
# Reuse the EOS token as the pad token when the tokenizer defines none (required for padding=True)
|
| 24 |
+
if tokenizer.pad_token is None:
|
| 25 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 26 |
+
|
| 27 |
+
# Alternatively, add a new padding token if it's not defined
|
| 28 |
+
# if tokenizer.pad_token is None:
|
| 29 |
+
# tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
| 30 |
+
# model.resize_token_embeddings(len(tokenizer))
|
| 31 |
|
| 32 |
def generate_response(prompt, max_new_tokens=150):
    """Generate a sampled continuation of *prompt* with the module-level model.

    Args:
        prompt: Text to feed to the model.
        max_new_tokens: Upper bound on the number of tokens generated *in
            addition to* the prompt. The previous ``max_length=150`` also
            counted the prompt tokens, so a long prompt left little or no
            room for a reply.

    Returns:
        The decoded first output sequence, with special tokens removed.
    """
    # Tokenize the prompt; the tokenizer also returns the attention mask.
    inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True)

    output = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_new_tokens=max_new_tokens,
        # Setting tokenizer.pad_token alone does not configure generate();
        # pass the id explicitly so it does not fall back to EOS with a warning.
        pad_token_id=tokenizer.pad_token_id,
        num_return_sequences=1,
        do_sample=True,
        top_k=50,
        top_p=0.95,
    )

    # Decode the first (and only) returned sequence.
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    return response
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
def save_chat():
|
| 52 |
chat_dir = './Intermediate-Chats'
|
| 53 |
if not os.path.exists(chat_dir):
|
|
|
|
| 80 |
role, content = line.strip().split(': ', 1)
|
| 81 |
st.session_state['messages'].append({'role': role, 'content': content})
|
| 82 |
|
| 83 |
+
def response_generator(content, delay=0.2):
    """Yield *content* one word at a time to simulate a streaming reply.

    Each yielded chunk is a single word followed by one space, so a caller
    that concatenates the chunks (``full_response += chunk``) rebuilds the
    text progressively. Yielding the cumulative text instead would repeat
    every earlier word once the caller appends it to its own buffer.

    Args:
        content: Full response text; it is split on any whitespace.
        delay: Seconds to pause after each word (default 0.2).

    Yields:
        ``word + " "`` for every whitespace-separated word in *content*.
    """
    for word in content.split():
        yield word + " "
        # Pause between words for the typing effect.
        time.sleep(delay)
|
| 89 |
+
|
| 90 |
def main():
|
| 91 |
st.title("LLaMA Chat Interface")
|
| 92 |
|
|
|
|
| 108 |
|
| 109 |
# Streaming response in the chat interface
|
| 110 |
with st.chat_message("assistant"):
|
| 111 |
+
placeholder = st.empty()
|
| 112 |
full_response = ""
|
| 113 |
for word in response_generator(response):
|
| 114 |
full_response += word
|
| 115 |
+
placeholder.write(full_response)
|
| 116 |
|
| 117 |
# Sidebar functionality
|
| 118 |
if st.sidebar.button("Save Chat"):
|