MatheusHRV committed on
Commit
3e2c5d9
·
verified ·
1 Parent(s): 0b211f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from langchain.chat_models import HuggingFaceHub
3
  from langchain.schema import AIMessage, HumanMessage, SystemMessage
4
 
5
  # ------------------------
@@ -14,21 +14,24 @@ if "sessionMessages" not in st.session_state:
14
  ]
15
 
16
  # ------------------------
17
- # Load Hugging Face Model
18
  # ------------------------
19
- chat = HuggingFaceHub(
20
- repo_id="OpenAssistant/oasst-sft-4-pythia-12b",
21
- model_kwargs={"temperature": 0, "max_new_tokens": 512}
 
 
 
 
22
  )
23
 
24
  # ------------------------
25
  # Helper Functions
26
  # ------------------------
27
  def load_answer(question):
28
- # Add human message
29
  st.session_state.sessionMessages.append(HumanMessage(content=question))
30
 
31
- # Convert messages to a single prompt string
32
  prompt = ""
33
  for msg in st.session_state.sessionMessages:
34
  if isinstance(msg, SystemMessage):
@@ -38,13 +41,13 @@ def load_answer(question):
38
  elif isinstance(msg, AIMessage):
39
  prompt += f"AI: {msg.content}\n"
40
 
41
- # Get AI response
42
- assistant_answer = chat(prompt)
 
43
 
44
- # Add AI message to session
45
- st.session_state.sessionMessages.append(AIMessage(content=assistant_answer.content))
46
 
47
- return assistant_answer.content
48
 
49
  def get_text():
50
  return st.text_input("You: ", key="input")
 
1
  import streamlit as st
2
+ from transformers import pipeline
3
  from langchain.schema import AIMessage, HumanMessage, SystemMessage
4
 
5
  # ------------------------
 
14
  ]
15
 
16
  # ------------------------
17
+ # Load Hugging Face pipeline
18
  # ------------------------
19
+ # Using a free, instruction-tuned model
20
+ generator = pipeline(
21
+ "text-generation",
22
+ model="OpenAssistant/oasst-sft-4-pythia-12b",
23
+ device=0, # GPU if available
24
+ max_new_tokens=512,
25
+ temperature=0
26
  )
27
 
28
  # ------------------------
29
  # Helper Functions
30
  # ------------------------
31
  def load_answer(question):
 
32
  st.session_state.sessionMessages.append(HumanMessage(content=question))
33
 
34
+ # Convert session messages to a single string prompt
35
  prompt = ""
36
  for msg in st.session_state.sessionMessages:
37
  if isinstance(msg, SystemMessage):
 
41
  elif isinstance(msg, AIMessage):
42
  prompt += f"AI: {msg.content}\n"
43
 
44
+ # Generate response
45
+ output = generator(prompt, max_new_tokens=512, do_sample=True, temperature=0)
46
+ answer_text = output[0]["generated_text"][len(prompt):].strip() # remove prompt from output
47
 
48
+ st.session_state.sessionMessages.append(AIMessage(content=answer_text))
 
49
 
50
+ return answer_text
51
 
52
  def get_text():
53
  return st.text_input("You: ", key="input")