Baldezo313 committed on
Commit
aa3e8af
·
verified ·
1 Parent(s): 76b6b3a

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +30 -32
src/streamlit_app.py CHANGED
@@ -1,45 +1,43 @@
1
  import os
2
- import streamlit as st
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
- import torch
5
-
6
- # 🔧 Fix permissions & paths (HF Spaces & Streamlit)
7
  os.environ['HOME'] = '/tmp'
8
  os.environ['XDG_CACHE_HOME'] = '/tmp/.cache'
9
  os.environ['HF_HOME'] = '/tmp/.hf'
10
  os.environ['TRANSFORMERS_CACHE'] = '/tmp/.hf/transformers'
11
  os.environ['STREAMLIT_HOME'] = '/tmp/.hf/streamlit'
12
- os.environ['STREAMLIT_CONFIG_FILE'] = '/tmp/.hf/streamlit/config.toml'
13
 
14
- # 🧠 Model name (change if needed)
15
- MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
 
16
 
17
- @st.cache_resource(show_spinner="🔄 Loading model...")
18
  def load_model():
19
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
20
- model = AutoModelForCausalLM.from_pretrained(
21
- MODEL_NAME,
22
- torch_dtype=torch.float32,
23
- device_map="auto"
24
- )
25
  return tokenizer, model
26
 
27
  tokenizer, model = load_model()
28
 
29
- st.title("💬 Simple LLM Chatbot (Streamlit + HF Transformers)")
30
-
31
- user_input = st.text_area("🧑‍💻 Posez votre question :", "", height=100)
32
-
33
- if st.button("Envoyer") and user_input.strip():
34
- with st.spinner("✍️ Génération en cours..."):
35
- input_ids = tokenizer.encode(user_input, return_tensors="pt")
36
- output = model.generate(
37
- input_ids,
38
- max_new_tokens=150,
39
- do_sample=True,
40
- top_k=50,
41
- top_p=0.95,
42
- temperature=0.7
43
- )
44
- response = tokenizer.decode(output[0], skip_special_tokens=True)
45
- st.markdown(f"**Réponse :**\n\n{response}")
 
 
 
 
 
 
 
1
import os

# Redirect every cache/config directory into /tmp — the only writable
# location inside a Hugging Face Spaces container.
for _var, _path in [
    ('HOME', '/tmp'),
    ('XDG_CACHE_HOME', '/tmp/.cache'),
    ('HF_HOME', '/tmp/.hf'),
    ('TRANSFORMERS_CACHE', '/tmp/.hf/transformers'),
    ('STREAMLIT_HOME', '/tmp/.hf/streamlit'),
]:
    os.environ[_var] = _path
 
7
 
8
+ import streamlit as st
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ import torch
11
 
12
@st.cache_resource
def load_model():
    """Load the OpenChat tokenizer and model once per Streamlit process.

    Cached with ``st.cache_resource`` so reruns reuse the same objects.

    Returns:
        tuple: ``(tokenizer, model)`` ready for inference on the best
        available device.
    """
    model_name = "openchat/openchat-3.5-0106"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # float16 halves memory on GPU, but on the CPU-only HF Spaces runtime
    # fp16 is very slow and several ops are unimplemented — fall back to fp32.
    use_cuda = torch.cuda.is_available()
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    model.to("cuda" if use_cuda else "cpu")
    model.eval()  # inference only: disable dropout etc.
    return tokenizer, model
18
 
19
tokenizer, model = load_model()

st.title("OpenChat 🤖")

# Seed the conversation with a greeting on first load.
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "Salut ! Pose-moi une question."}]

# Replay the whole history so the chat survives Streamlit's rerun-on-input model.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

query = st.chat_input("Votre message...")

if query:
    st.session_state.messages.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    inputs = tokenizer(query, return_tensors="pt").to(model.device)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True, top_p=0.95, top_k=50)
    # BUG FIX: generate() returns prompt + continuation in one sequence;
    # decoding outputs[0] echoed the user's prompt at the start of every
    # reply. Slice off the prompt tokens and decode only the new ones.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    st.session_state.messages.append({"role": "assistant", "content": response})
    with st.chat_message("assistant"):
        st.markdown(response)