Spaces:
Sleeping
Sleeping
Leonardo Parente committed on
Commit ·
960c913
1
Parent(s): bc4906f
add logo
Browse files
- app.py +50 -37
- orbgptlogo.png +0 -0
app.py
CHANGED
|
@@ -1,16 +1,19 @@
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 3 |
from langchain.memory import ConversationBufferMemory
|
| 4 |
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
|
| 5 |
-
from langchain.chains import
|
| 6 |
-
from langchain.prompts import PromptTemplate
|
| 7 |
from langchain.embeddings import VoyageEmbeddings
|
| 8 |
from langchain.vectorstores import SupabaseVectorStore
|
| 9 |
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
| 10 |
from st_supabase_connection import SupabaseConnection
|
| 11 |
|
| 12 |
msgs = StreamlitChatMessageHistory()
|
| 13 |
-
memory = ConversationBufferMemory(
|
|
|
|
|
|
|
| 14 |
|
| 15 |
supabase_client = st.connection(
|
| 16 |
name="orbgpt",
|
|
@@ -18,45 +21,51 @@ supabase_client = st.connection(
|
|
| 18 |
ttl=None,
|
| 19 |
)
|
| 20 |
|
| 21 |
-
embeddings = VoyageEmbeddings(model="voyage-01")
|
| 22 |
-
vector_store = SupabaseVectorStore(
|
| 23 |
-
embedding=embeddings,
|
| 24 |
-
client=supabase_client,
|
| 25 |
-
table_name="documents",
|
| 26 |
-
query_name="match_documents",
|
| 27 |
-
)
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
model_path = "01-ai/Yi-6B-Chat"
|
| 31 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
| 32 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 33 |
-
model_path,
|
| 34 |
-
device_map="auto",
|
| 35 |
-
offload_folder="offload",
|
| 36 |
-
offload_state_dict=True,
|
| 37 |
-
torch_dtype="auto",
|
| 38 |
-
).eval()
|
| 39 |
-
pipe = pipeline(
|
| 40 |
-
"text-generation",
|
| 41 |
-
model=model,
|
| 42 |
-
tokenizer=tokenizer,
|
| 43 |
-
max_new_tokens=10,
|
| 44 |
-
use_fast=False,
|
| 45 |
-
)
|
| 46 |
-
hf = HuggingFacePipeline(pipeline=pipe)
|
| 47 |
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
Answer: Let's think step by step."""
|
| 51 |
-
prompt = PromptTemplate.from_template(template)
|
| 52 |
|
| 53 |
-
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
st.title("🪩🤖")
|
| 60 |
|
| 61 |
if len(msgs.messages) == 0:
|
| 62 |
msgs.add_ai_message("Ask me anything about orb community projects!")
|
|
@@ -66,5 +75,9 @@ for msg in msgs.messages:
|
|
| 66 |
|
| 67 |
if prompt := st.chat_input("Ask something"):
|
| 68 |
st.chat_message("human").write(prompt)
|
| 69 |
-
|
| 70 |
-
st.chat_message("ai")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
from pathlib import Path
|
| 3 |
import streamlit as st
|
| 4 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 5 |
from langchain.memory import ConversationBufferMemory
|
| 6 |
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
|
| 7 |
+
from langchain.chains import ConversationalRetrievalChain
|
|
|
|
| 8 |
from langchain.embeddings import VoyageEmbeddings
|
| 9 |
from langchain.vectorstores import SupabaseVectorStore
|
| 10 |
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
| 11 |
from st_supabase_connection import SupabaseConnection
|
| 12 |
|
| 13 |
msgs = StreamlitChatMessageHistory()
|
| 14 |
+
memory = ConversationBufferMemory(
|
| 15 |
+
memory_key="history", chat_memory=msgs, return_messages=True
|
| 16 |
+
)
|
| 17 |
|
| 18 |
supabase_client = st.connection(
|
| 19 |
name="orbgpt",
|
|
|
|
| 21 |
ttl=None,
|
| 22 |
)
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
+
@st.cache_resource
|
| 26 |
+
def load_retriever():
|
| 27 |
+
# load embeddings using VoyageAI and Supabase
|
| 28 |
+
embeddings = VoyageEmbeddings(model="voyage-01")
|
| 29 |
+
vector_store = SupabaseVectorStore(
|
| 30 |
+
embedding=embeddings,
|
| 31 |
+
client=supabase_client.client,
|
| 32 |
+
table_name="documents",
|
| 33 |
+
query_name="match_documents",
|
| 34 |
+
)
|
| 35 |
+
return vector_store.as_retriever()
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
@st.cache_resource
|
| 39 |
+
def load_model():
|
| 40 |
+
model_path = "llmware/bling-falcon-1b-0.1"
|
| 41 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
| 42 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 43 |
+
model_path,
|
| 44 |
+
device_map="auto",
|
| 45 |
+
offload_folder="offload",
|
| 46 |
+
offload_state_dict=True,
|
| 47 |
+
torch_dtype="auto",
|
| 48 |
+
).eval()
|
| 49 |
+
pipe = pipeline(
|
| 50 |
+
"text-generation",
|
| 51 |
+
model=model,
|
| 52 |
+
tokenizer=tokenizer,
|
| 53 |
+
use_fast=False,
|
| 54 |
+
)
|
| 55 |
+
return HuggingFacePipeline(pipeline=pipe)
|
| 56 |
|
|
|
|
|
|
|
| 57 |
|
| 58 |
+
hf = load_model()
|
| 59 |
+
retriever = load_retriever()
|
| 60 |
+
chat = ConversationalRetrievalChain.from_llm(hf, retriever)
|
| 61 |
|
| 62 |
+
st.markdown(
|
| 63 |
+
"<div style='display: flex;justify-content: center;'><img width='150' src='data:image/png;base64,{}' class='img-fluid'></div>".format(
|
| 64 |
+
base64.b64encode(Path("orbgptlogo.png").read_bytes()).decode()
|
| 65 |
+
),
|
| 66 |
+
unsafe_allow_html=True,
|
| 67 |
+
)
|
| 68 |
|
|
|
|
| 69 |
|
| 70 |
if len(msgs.messages) == 0:
|
| 71 |
msgs.add_ai_message("Ask me anything about orb community projects!")
|
|
|
|
| 75 |
|
| 76 |
if prompt := st.chat_input("Ask something"):
|
| 77 |
st.chat_message("human").write(prompt)
|
| 78 |
+
msgs.add_user_message(prompt)
|
| 79 |
+
with st.chat_message("ai"):
|
| 80 |
+
with st.spinner("Processing your question..."):
|
| 81 |
+
response = chat({"question": prompt, "chat_history": memory.buffer})
|
| 82 |
+
msgs.add_ai_message(response["answer"])
|
| 83 |
+
st.write(response["answer"])
|
orbgptlogo.png
ADDED
|