Baldezo313 commited on
Commit
bb5f8b2
·
verified ·
1 Parent(s): 6f4b7c5

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +8 -7
src/streamlit_app.py CHANGED
@@ -1,7 +1,8 @@
1
  import os
 
 
2
  os.environ['HF_HOME'] = '/tmp/.hf'
3
  os.environ['TRANSFORMERS_CACHE'] = '/tmp/.hf/transformers'
4
- os.environ['XDG_CACHE_HOME'] = '/tmp/.hf/cache'
5
  os.environ['STREAMLIT_HOME'] = '/tmp/.hf/streamlit'
6
 
7
  import streamlit as st
@@ -17,14 +18,14 @@ def load_model():
17
 
18
  tokenizer, model = load_model()
19
 
20
- st.title("OpenChat - Demo")
21
 
22
  if "messages" not in st.session_state:
23
- st.session_state.messages = [{"role": "assistant", "content": "Posez-moi une question !"}]
24
 
25
- for message in st.session_state.messages:
26
- with st.chat_message(message["role"]):
27
- st.markdown(message["content"])
28
 
29
  query = st.chat_input("Votre message...")
30
 
@@ -34,7 +35,7 @@ if query:
34
  st.markdown(query)
35
 
36
  inputs = tokenizer(query, return_tensors="pt").to(model.device)
37
- outputs = model.generate(**inputs, max_new_tokens=150)
38
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
39
 
40
  st.session_state.messages.append({"role": "assistant", "content": response})
 
1
  import os
2
+ os.environ['HOME'] = '/tmp'
3
+ os.environ['XDG_CACHE_HOME'] = '/tmp/.cache'
4
  os.environ['HF_HOME'] = '/tmp/.hf'
5
  os.environ['TRANSFORMERS_CACHE'] = '/tmp/.hf/transformers'
 
6
  os.environ['STREAMLIT_HOME'] = '/tmp/.hf/streamlit'
7
 
8
  import streamlit as st
 
18
 
19
  tokenizer, model = load_model()
20
 
21
+ st.title("OpenChat 🤖")
22
 
23
  if "messages" not in st.session_state:
24
+ st.session_state.messages = [{"role": "assistant", "content": "Salut ! Pose-moi une question."}]
25
 
26
+ for msg in st.session_state.messages:
27
+ with st.chat_message(msg["role"]):
28
+ st.markdown(msg["content"])
29
 
30
  query = st.chat_input("Votre message...")
31
 
 
35
  st.markdown(query)
36
 
37
  inputs = tokenizer(query, return_tensors="pt").to(model.device)
38
+ outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True, top_p=0.95, top_k=50)
39
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
40
 
41
  st.session_state.messages.append({"role": "assistant", "content": response})