MatheusHRV committed on
Commit
e51057e
·
verified ·
1 Parent(s): 3e2c5d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -18
app.py CHANGED
@@ -2,9 +2,6 @@ import streamlit as st
2
  from transformers import pipeline
3
  from langchain.schema import AIMessage, HumanMessage, SystemMessage
4
 
5
- # ------------------------
6
- # Streamlit UI
7
- # ------------------------
8
  st.set_page_config(page_title="LangChain Demo", page_icon=":robot:")
9
  st.header("MHRV Chatbot")
10
 
@@ -13,25 +10,18 @@ if "sessionMessages" not in st.session_state:
13
  SystemMessage(content="You are a customer support chatbot for a website.")
14
  ]
15
 
16
- # ------------------------
17
- # Load Hugging Face pipeline
18
- # ------------------------
19
- # Using a free, instruction-tuned model
20
  generator = pipeline(
21
  "text-generation",
22
  model="OpenAssistant/oasst-sft-4-pythia-12b",
23
- device=0, # GPU if available
24
  max_new_tokens=512,
25
  temperature=0
26
  )
27
 
28
- # ------------------------
29
- # Helper Functions
30
- # ------------------------
31
  def load_answer(question):
32
  st.session_state.sessionMessages.append(HumanMessage(content=question))
33
 
34
- # Convert session messages to a single string prompt
35
  prompt = ""
36
  for msg in st.session_state.sessionMessages:
37
  if isinstance(msg, SystemMessage):
@@ -41,20 +31,15 @@ def load_answer(question):
41
  elif isinstance(msg, AIMessage):
42
  prompt += f"AI: {msg.content}\n"
43
 
44
- # Generate response
45
  output = generator(prompt, max_new_tokens=512, do_sample=True, temperature=0)
46
- answer_text = output[0]["generated_text"][len(prompt):].strip() # remove prompt from output
47
 
48
  st.session_state.sessionMessages.append(AIMessage(content=answer_text))
49
-
50
  return answer_text
51
 
52
  def get_text():
53
  return st.text_input("You: ", key="input")
54
 
55
- # ------------------------
56
- # Main App
57
- # ------------------------
58
  user_input = get_text()
59
  submit = st.button("Generate")
60
 
 
2
  from transformers import pipeline
3
  from langchain.schema import AIMessage, HumanMessage, SystemMessage
4
 
 
 
 
5
  st.set_page_config(page_title="LangChain Demo", page_icon=":robot:")
6
  st.header("MHRV Chatbot")
7
 
 
10
  SystemMessage(content="You are a customer support chatbot for a website.")
11
  ]
12
 
13
+ # Hugging Face model using transformers pipeline
 
 
 
14
  generator = pipeline(
15
  "text-generation",
16
  model="OpenAssistant/oasst-sft-4-pythia-12b",
17
+ device=0, # use GPU if available
18
  max_new_tokens=512,
19
  temperature=0
20
  )
21
 
 
 
 
22
  def load_answer(question):
23
  st.session_state.sessionMessages.append(HumanMessage(content=question))
24
 
 
25
  prompt = ""
26
  for msg in st.session_state.sessionMessages:
27
  if isinstance(msg, SystemMessage):
 
31
  elif isinstance(msg, AIMessage):
32
  prompt += f"AI: {msg.content}\n"
33
 
 
34
  output = generator(prompt, max_new_tokens=512, do_sample=True, temperature=0)
35
+ answer_text = output[0]["generated_text"][len(prompt):].strip()
36
 
37
  st.session_state.sessionMessages.append(AIMessage(content=answer_text))
 
38
  return answer_text
39
 
40
  def get_text():
41
  return st.text_input("You: ", key="input")
42
 
 
 
 
43
  user_input = get_text()
44
  submit = st.button("Generate")
45