Prajjwalng committed on
Commit 3ee9c30 · verified · 1 Parent(s): 490d9e3

Update app.py

Files changed (1):
  1. app.py +71 -0
app.py CHANGED
@@ -21,3 +21,74 @@ else:
  print("CUDA is not available. Using CPU.")

  print(f"Using device: {device}")
+
+ @st.cache_resource
+ def load_model():
+     # Load the DialoGPT tokenizer and model once and cache them across Streamlit reruns
+     tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
+     model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
+     return tokenizer, model
+
+ tokenizer, model = load_model()
+
+ # Function to generate chatbot response
+ def generate_response(prompt, chat_history_ids=None):
+     new_input_ids = tokenizer.encode(prompt + tokenizer.eos_token, return_tensors="pt")
+
+     # Append the new user input to the conversation history, if there is any
+     if chat_history_ids is None:
+         bot_input_ids = new_input_ids
+     else:
+         bot_input_ids = torch.cat([torch.tensor(chat_history_ids), new_input_ids], dim=-1)
+
+     # Generate a response while limiting the total chat history to 1000 tokens
+     chat_history_ids = model.generate(
+         bot_input_ids,
+         max_length=1000,
+         pad_token_id=tokenizer.eos_token_id,
+         no_repeat_ngram_size=3,
+         do_sample=True,
+         temperature=0.7,
+         top_k=50,
+         top_p=0.95,
+     )
+
+     # Decode only the newly generated tokens (everything after the input)
+     response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
+     return response, chat_history_ids.tolist()
+
+ # Streamlit app
+ st.title("Simple Chatbot")
+
+ # Initialize chat history
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+ if "chat_history_ids" not in st.session_state:
+     st.session_state.chat_history_ids = None
+
+ # Display chat messages from history on app rerun
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ # Accept user input
+ if prompt := st.chat_input("What is up?"):
+     # Add user message to chat history
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     # Display user message in chat message container
+     with st.chat_message("user"):
+         st.markdown(prompt)
+
+     # Generate and display chatbot response
+     with st.chat_message("assistant"):
+         message_placeholder = st.empty()
+         full_response = ""
+         response, st.session_state.chat_history_ids = generate_response(prompt, st.session_state.chat_history_ids)
+
+         # Simulate a stream of responses with a small delay per word
+         import time
+         for chunk in response.split():
+             full_response += chunk + " "
+             time.sleep(0.05)
+             # Update the placeholder to stream the response with a cursor
+             message_placeholder.markdown(full_response + "▌")
+         message_placeholder.markdown(full_response)
+
+     # Add assistant response to chat history
+     st.session_state.messages.append({"role": "assistant", "content": full_response})
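
Note: this hunk starts at line 21 of app.py, so the added code relies on imports and device setup defined earlier in the file (the hunk header shows it sits inside an else: branch, and the context lines reference a device variable). Those earlier lines are not part of this diff; the sketch below is only an assumption of what the top of app.py would need to look like for the new code to run, not the file's actual contents:

    import torch
    import streamlit as st
    from transformers import AutoTokenizer, AutoModelForCausalLM

    # Pick GPU when available, otherwise fall back to CPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("CUDA is available. Using GPU.")
    else:
        device = torch.device("cpu")
        print("CUDA is not available. Using CPU.")

    print(f"Using device: {device}")

With a preamble along those lines in place, the app can be launched locally with `streamlit run app.py`.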