Prajjwalng committed on
Commit b8df990 · verified · 1 Parent(s): 9879c04

Update app.py

Files changed (1)
  1. app.py +11 -38
app.py CHANGED
@@ -1,43 +1,15 @@
-
-import streamlit as st
-import os
-hf_token = os.environ.get("HF_TOKEN")
-if hf_token:
-    # Use the token
-    from huggingface_hub import login
-    login(token=hf_token)
-    # your code that requires the token
-else:
-    print("HF_TOKEN environment variable not set.")
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-import torch
-if torch.cuda.is_available():
-    device = torch.device("cuda")
-    print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")  # prints GPU name
-    print(f"Number of GPUs available: {torch.cuda.device_count()}")  # prints number of GPUs
-    print(f"Current GPU device: {torch.cuda.current_device()}")  # prints current GPU id
-else:
-    device = torch.device("cpu")
-    print("CUDA is not available. Using CPU.")
-
-print(f"Using device: {device}")
-
+# Initialize model and tokenizer (load only once)
 @st.cache_resource
 def load_model():
-    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
-    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")
+    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
+    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
     return tokenizer, model
 
 tokenizer, model = load_model()
 
 # Function to generate chatbot response
-def generate_response(prompt, chat_history_ids=None):
-    inputs = tokenizer.encode(prompt + tokenizer.eos_token, return_tensors="pt")
-
-    if chat_history_ids is None:
-        chat_history_ids = None
-    else:
-        chat_history_ids = torch.tensor(chat_history_ids)
+def generate_response(prompt, chat_history=""):
+    inputs = tokenizer.encode(chat_history + prompt + tokenizer.eos_token, return_tensors="pt")
 
     # generate a response while limiting the total chat history to 1000 tokens
     chat_history_ids = model.generate(
@@ -47,11 +19,10 @@ def generate_response(prompt, chat_history_ids=None):
         temperature=0.7,
         top_k=50,
         top_p=0.95,
-        chat_history_ids=chat_history_ids
     )
 
     response = tokenizer.decode(chat_history_ids[:, inputs.shape[-1]:][0], skip_special_tokens=True)
-    return response, chat_history_ids.tolist()
+    return response
 
 # Streamlit app
 st.title("Simple Chatbot")
@@ -59,8 +30,8 @@ st.title("Simple Chatbot")
 # Initialize chat history
 if "messages" not in st.session_state:
     st.session_state.messages = []
-if "chat_history_ids" not in st.session_state:
-    st.session_state.chat_history_ids = None
+if "chat_history" not in st.session_state:
+    st.session_state.chat_history = ""
 
 # Display chat messages from history on app rerun
 for message in st.session_state.messages:
@@ -79,7 +50,7 @@ if prompt := st.chat_input("What is up?"):
     with st.chat_message("assistant"):
         message_placeholder = st.empty()
         full_response = ""
-        response, st.session_state.chat_history_ids = generate_response(prompt, st.session_state.chat_history_ids)
+        response = generate_response(prompt, st.session_state.chat_history)
 
         # Simulate stream of responses with milliseconds delay
         import time
@@ -92,3 +63,5 @@ if prompt := st.chat_input("What is up?"):
 
     # Add assistant response to chat history
     st.session_state.messages.append({"role": "assistant", "content": full_response})
+    # update the chat history
+    st.session_state.chat_history += prompt + tokenizer.eos_token + response + tokenizer.eos_token
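For reference, here is a minimal, self-contained sketch of the string-based history pattern this commit switches to. The diff elides the middle of the model.generate(...) call (old lines 44-46 / new lines 16-18), so max_length, do_sample, and pad_token_id below are assumptions consistent with the visible arguments and the "1000 tokens" comment, not a verbatim copy of the committed file.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

def generate_response(prompt: str, chat_history: str = "") -> str:
    # Encode the accumulated history plus the new user turn; each turn ends in EOS.
    inputs = tokenizer.encode(chat_history + prompt + tokenizer.eos_token, return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(
            inputs,
            max_length=1000,                      # assumed cap, per the "1000 tokens" comment
            do_sample=True,                       # assumed; needed for temperature/top_k/top_p to take effect
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,  # assumed; DialoGPT defines no pad token
        )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    return tokenizer.decode(output_ids[:, inputs.shape[-1]:][0], skip_special_tokens=True)

# One turn of the loop: history grows by prompt + EOS + response + EOS, as in the diff.
history = ""
response = generate_response("Hello!", history)
history += "Hello!" + tokenizer.eos_token + response + tokenizer.eos_token

Keeping the history as a plain string rather than the previous token-ID tensor makes it trivially storable in st.session_state across Streamlit reruns, which is presumably why the commit drops the tolist()/torch.tensor round-trip.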