Shabbir-Anjum commited on
Commit
03914f2
Β·
verified Β·
1 Parent(s): 1d614bc

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +3 -6
src/streamlit_app.py CHANGED
@@ -3,12 +3,12 @@ import streamlit as st
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
 
6
- # βœ… Force cache dir to /tmp (writeable in Hugging Face Spaces)
7
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
8
 
9
  @st.cache_resource
10
  def load_model():
11
- model_name = "microsoft/DialoGPT-medium"
12
  tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/hf_cache")
13
  model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="/tmp/hf_cache")
14
  return tokenizer, model
@@ -16,9 +16,8 @@ def load_model():
16
  tokenizer, model = load_model()
17
 
18
  st.set_page_config(page_title="Chatbot πŸ€–", page_icon="πŸ’¬", layout="centered")
19
- st.title("πŸ€– Hugging Face Chatbot with Transformers + Streamlit")
20
 
21
- # Session state
22
  if "chat_history_ids" not in st.session_state:
23
  st.session_state.chat_history_ids = None
24
  if "past_inputs" not in st.session_state:
@@ -26,7 +25,6 @@ if "past_inputs" not in st.session_state:
26
  if "generated_responses" not in st.session_state:
27
  st.session_state.generated_responses = []
28
 
29
- # User input
30
  user_input = st.text_input("You: ", "", key="input")
31
 
32
  if st.button("Send") and user_input:
@@ -51,7 +49,6 @@ if st.button("Send") and user_input:
51
  st.session_state.past_inputs.append(user_input)
52
  st.session_state.generated_responses.append(bot_output)
53
 
54
- # Display conversation
55
  if st.session_state.generated_responses:
56
  for i in range(len(st.session_state.generated_responses)):
57
  st.markdown(f"**You:** {st.session_state.past_inputs[i]}")
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
 
6
+ # βœ… Use a cache directory that Spaces allows
7
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
8
 
9
  @st.cache_resource
10
  def load_model():
11
+ model_name = "microsoft/DialoGPT-small" # switched from -medium
12
  tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/hf_cache")
13
  model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="/tmp/hf_cache")
14
  return tokenizer, model
 
16
  tokenizer, model = load_model()
17
 
18
  st.set_page_config(page_title="Chatbot πŸ€–", page_icon="πŸ’¬", layout="centered")
19
+ st.title("πŸ€– Hugging Face Chatbot (DialoGPT-small)")
20
 
 
21
  if "chat_history_ids" not in st.session_state:
22
  st.session_state.chat_history_ids = None
23
  if "past_inputs" not in st.session_state:
 
25
  if "generated_responses" not in st.session_state:
26
  st.session_state.generated_responses = []
27
 
 
28
  user_input = st.text_input("You: ", "", key="input")
29
 
30
  if st.button("Send") and user_input:
 
49
  st.session_state.past_inputs.append(user_input)
50
  st.session_state.generated_responses.append(bot_output)
51
 
 
52
  if st.session_state.generated_responses:
53
  for i in range(len(st.session_state.generated_responses)):
54
  st.markdown(f"**You:** {st.session_state.past_inputs[i]}")