Baldezo313 commited on
Commit
6f4b7c5
·
verified ·
1 Parent(s): 6e14416

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +11 -16
src/streamlit_app.py CHANGED
@@ -1,10 +1,8 @@
1
  import os
2
-
3
- # Rediriger les caches vers un dossier accessible en écriture
4
- os.environ['TRANSFORMERS_CACHE'] = '/tmp/.cache/huggingface/transformers'
5
- os.environ['HF_HOME'] = '/tmp/.cache/huggingface'
6
- os.environ['XDG_CACHE_HOME'] = '/tmp/.cache'
7
- os.environ['STREAMLIT_HOME'] = '/tmp/.streamlit'
8
 
9
  import streamlit as st
10
  from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -14,12 +12,12 @@ import torch
14
  def load_model():
15
  model_name = "openchat/openchat-3.5-0106"
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
18
  return tokenizer, model
19
 
20
  tokenizer, model = load_model()
21
 
22
- st.title("OpenChat 3.5 Demo")
23
 
24
  if "messages" not in st.session_state:
25
  st.session_state.messages = [{"role": "assistant", "content": "Posez-moi une question !"}]
@@ -28,20 +26,17 @@ for message in st.session_state.messages:
28
  with st.chat_message(message["role"]):
29
  st.markdown(message["content"])
30
 
31
- def generate_response(prompt):
32
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
33
- outputs = model.generate(**inputs, max_new_tokens=150)
34
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
35
- return response
36
-
37
- query = st.chat_input("Votre question ici...")
38
 
39
  if query:
40
  st.session_state.messages.append({"role": "user", "content": query})
41
  with st.chat_message("user"):
42
  st.markdown(query)
43
 
44
- response = generate_response(query)
 
 
 
45
  st.session_state.messages.append({"role": "assistant", "content": response})
46
  with st.chat_message("assistant"):
47
  st.markdown(response)
 
1
  import os
2
+ os.environ['HF_HOME'] = '/tmp/.hf'
3
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/.hf/transformers'
4
+ os.environ['XDG_CACHE_HOME'] = '/tmp/.hf/cache'
5
+ os.environ['STREAMLIT_HOME'] = '/tmp/.hf/streamlit'
 
 
6
 
7
  import streamlit as st
8
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
12
  def load_model():
13
  model_name = "openchat/openchat-3.5-0106"
14
  tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
16
  return tokenizer, model
17
 
18
  tokenizer, model = load_model()
19
 
20
+ st.title("OpenChat - Demo")
21
 
22
  if "messages" not in st.session_state:
23
  st.session_state.messages = [{"role": "assistant", "content": "Posez-moi une question !"}]
 
26
  with st.chat_message(message["role"]):
27
  st.markdown(message["content"])
28
 
29
+ query = st.chat_input("Votre message...")
 
 
 
 
 
 
30
 
31
  if query:
32
  st.session_state.messages.append({"role": "user", "content": query})
33
  with st.chat_message("user"):
34
  st.markdown(query)
35
 
36
+ inputs = tokenizer(query, return_tensors="pt").to(model.device)
37
+ outputs = model.generate(**inputs, max_new_tokens=150)
38
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
39
+
40
  st.session_state.messages.append({"role": "assistant", "content": response})
41
  with st.chat_message("assistant"):
42
  st.markdown(response)