Spaces:
Runtime error
Runtime error
| import os | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image | |
| import streamlit as st | |
| from app_config import SYSTEM_PROMPT,MODEL,MAX_TOKENS,TRANSFORMER_MODEL | |
| from langchain.memory import ConversationSummaryBufferMemory | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_groq import ChatGroq | |
| from streamlit_pdf_viewer import pdf_viewer | |
| from pydantic import BaseModel | |
| from langchain.chains import LLMChain | |
| from langchain.prompts import ChatPromptTemplate | |
| from langchain_community.vectorstores import FAISS | |
| from sentence_transformers import SentenceTransformer | |
| from typing import Any | |
# Heading rendered at the top of the main page.
_page_title = "Hitachi Support Bot"
st.title(_page_title)
class Element(BaseModel):
    """Pydantic record pairing a type label with an arbitrary payload.

    NOTE(review): nothing in the visible part of this file instantiates
    this model; it presumably mirrors the records used when the index was
    built -- confirm before removing.
    """

    # Label for the kind of element (the retrieval code in this file
    # distinguishes 'text', 'table' and 'image' metadata types).
    type: str
    # Payload for the element; deliberately untyped.
    text: Any
# Alternative backend (disabled):
#   ChatGoogleGenerativeAI(model=MODEL, max_tokens=MAX_TOKENS)


@st.cache_resource
def _load_qa_resources():
    """Build the LLM, prompt, chain, embeddings and FAISS index once.

    Streamlit re-executes this script from the top on every user
    interaction. Wrapping the construction in ``st.cache_resource`` keeps a
    single shared copy of these objects instead of re-loading the embedding
    model and the on-disk index on each rerun.

    Returns:
        tuple: ``(llm, prompt, qa_chain, embeddings, db)`` in that order.
    """
    chat_model = ChatGroq(model=MODEL, api_key=os.getenv('API_KEY'))
    template = ChatPromptTemplate.from_template(SYSTEM_PROMPT)
    chain = LLMChain(llm=chat_model, prompt=template)
    embedder = HuggingFaceEmbeddings(model_name=TRANSFORMER_MODEL)
    # SECURITY: the saved index is unpickled on load, which can execute
    # arbitrary code. Only keep this flag while "faiss_index" is produced by
    # a trusted build step and shipped alongside the app.
    index = FAISS.load_local(
        "faiss_index",
        embedder,
        allow_dangerous_deserialization=True,
    )
    return chat_model, template, chain, embedder, index


# Module-level names kept so the rest of the script can use them unchanged.
llm, prompt, qa_chain, embeddings, db = _load_qa_resources()
# Stylesheet that reverses the row direction and right-aligns the text of the
# element carrying this generated class name.
# NOTE(review): the selector is an auto-generated class name and presumably
# targets the user chat bubble -- verify it still matches after Streamlit
# upgrades.
_chat_alignment_css = """
<style>
    .st-emotion-cache-janbn0 {
        flex-direction: row-reverse;
        text-align: right;
    }
</style>
"""

st.markdown(_chat_alignment_css, unsafe_allow_html=True)
def response_generator(question):
    """Answer *question* from retrieved index entries and collect images.

    The five best matches are fetched from ``db``; matches whose relevance
    score is not positive are dropped. The remaining entries are turned
    into a context string which is passed, with the question, to
    ``qa_chain``.

    Args:
        question: The user's query text.

    Returns:
        tuple: ``(result, relevant_images)`` where ``result`` is whatever
        ``qa_chain.run`` returns and ``relevant_images`` is the list of
        ``original_content`` values of the matched image entries (the chat
        UI in this file decodes them as base64 image data).
    """
    scored_docs = db.similarity_search_with_relevance_scores(question, k=5)
    kept_docs = [doc for doc, score in scored_docs if score > 0]

    context_parts = []
    relevant_images = []
    for doc in kept_docs:
        kind = doc.metadata['type']
        if kind in ('text', 'table'):
            # Text and table entries contribute their stored original content.
            context_parts.append(str(doc.metadata['original_content']))
        elif kind == 'image':
            # Image entries contribute their indexed page content to the
            # context; the stored original is handed back for display.
            context_parts.append(doc.page_content)
            relevant_images.append(doc.metadata['original_content'])

    # Separate chunks with a blank line so the end of one chunk does not run
    # straight into the start of the next.
    context = "\n\n".join(context_parts)

    result = qa_chain.run({'context': context, "question": question})
    return result, relevant_images
def _decode_image(encoded):
    """Turn a base64-encoded string into a PIL image for ``st.image``."""
    return Image.open(BytesIO(base64.b64decode(encoded.encode('utf-8'))))


# NOTE(review): block nesting below was reconstructed from a flattened
# listing -- confirm that the toggle handling belongs outside the sidebar.
with st.sidebar:
    st.header("Hitachi Support Bot")
    button = st.toggle("View Doc file.")

if button:
    # Toggle on: show the reference PDF instead of the chat.
    pdf_viewer("GPT OUTPUT.pdf")
else:
    # First run of a session: seed the chat history, LLM handle and memory.
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    if "llm" not in st.session_state:
        st.session_state.llm = llm
    if "rag_memory" not in st.session_state:
        st.session_state.rag_memory = ConversationSummaryBufferMemory(
            llm=st.session_state.llm,
            max_token_limit=5000,
        )

    container = st.container(height=700)

    # Replay the stored conversation; the system message is never shown.
    for message in st.session_state.messages:
        role = message["role"]
        if role not in ("user", "assistant"):
            continue
        with container.chat_message(role):
            st.write(message["content"])
            # Only assistant entries are stored with an 'images' list.
            for encoded in message.get("images", []):
                st.image(_decode_image(encoded))

    # Named user_query (not prompt) so the module-level prompt template is
    # not overwritten by the submitted text.
    if user_query := st.chat_input("Enter your query here... "):
        with container.chat_message("user"):
            st.write(user_query)
        st.session_state.messages.append({"role": "user", "content": user_query})

        with container.chat_message("assistant"):
            response, images = response_generator(user_query)
            st.write(response)
            for encoded in images:
                st.markdown("""---""")
                st.image(_decode_image(encoded))
                st.markdown("""---""")

        # The exchange is recorded in the summary memory; nothing in the
        # visible part of this file reads it back into the prompt.
        st.session_state.rag_memory.save_context(
            {'input': user_query},
            {'output': response},
        )
        st.session_state.messages.append(
            {"role": "assistant", "content": response, 'images': images}
        )