import os
import streamlit as st
import tempfile
import pandas as pd
import json
import xml.etree.ElementTree as ET
import yaml
from dotenv import load_dotenv
from bs4 import BeautifulSoup
from pptx import Presentation
from docx import Document

from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.chains.llm import LLMChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate

from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.foundation_models.extensions.langchain import WatsonxLLM
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
from ibm_watson_machine_learning.foundation_models.utils.enums import DecodingMethods
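
# Module-level state: the FAISS index and the RAG chain are created once a file is uploaded.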
index = None
rag_chain = None
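

# Lightweight loaders for formats read directly with python-docx and python-pptx;
# each returns the document's text as a single string.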
class DocxLoader:
    def __init__(self, file_path):
        self.file_path = file_path

    def load(self):
        document = Document(self.file_path)
        text_content = [para.text for para in document.paragraphs]
        return " ".join(text_content)
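

# Collects the text of every shape on every slide of a .pptx deck.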
class PptxLoader:
    def __init__(self, file_path):
        self.file_path = file_path

    def load(self):
        presentation = Presentation(self.file_path)
        text_content = [
            shape.text
            for slide in presentation.slides
            for shape in slide.shapes
            if hasattr(shape, "text")
        ]
        return " ".join(text_content)
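

# Helpers that flatten structured formats (CSV, JSON, XML, YAML, HTML) to plain text
# so they can be chunked and indexed like any other document.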
def load_csv(file_path):
    df = pd.read_csv(file_path)
    # Paginated preview of the CSV. Note: these widgets run inside a function that is
    # wrapped by st.cache_resource (load_file), which Streamlit restricts by default.
    page_size = 100
    page_number = st.number_input(
        "Page number",
        min_value=1,
        max_value=max(1, (len(df) + page_size - 1) // page_size),
        step=1,
        value=1,
    )
    start_index = (page_number - 1) * page_size
    end_index = start_index + page_size
    st.dataframe(df.iloc[start_index:end_index])
    return df.to_string(index=False)


def load_json(file_path):
    with open(file_path, "r") as file:
        data = json.load(file)
    return json.dumps(data, indent=2)


def load_xml(file_path):
    tree = ET.parse(file_path)
    root = tree.getroot()
    return ET.tostring(root, encoding="unicode")


def load_yaml(file_path):
    with open(file_path, "r") as file:
        data = yaml.safe_load(file)
    return yaml.dump(data)


def load_html(file_path):
    with open(file_path, "r", encoding="utf-8") as file:
        soup = BeautifulSoup(file, "html.parser")
    return soup.get_text()
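

# Turns an uploaded file into a FAISS vector index. st.cache_resource memoizes the result
# per (file_name, file_type), so re-uploading the same name reuses the existing index.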
@st.cache_resource
def load_file(file_name, file_type):
    loaders = []
    text = None

    if file_type == "pdf":
        loaders = [PyPDFLoader(file_name)]
    elif file_type == "docx":
        loader = DocxLoader(file_name)
        text = loader.load()
    elif file_type == "pptx":
        loader = PptxLoader(file_name)
        text = loader.load()
    elif file_type == "txt":
        loaders = [TextLoader(file_name)]
    elif file_type == "csv":
        text = load_csv(file_name)
    elif file_type == "json":
        text = load_json(file_name)
    elif file_type == "xml":
        text = load_xml(file_name)
    elif file_type == "yaml":
        text = load_yaml(file_name)
    elif file_type in ("html", "htm"):
        text = load_html(file_name)
    else:
        st.error("Unsupported file type.")
        return None

    # Formats that were flattened to a plain string are written to a temporary .txt file
    # so TextLoader can turn them into LangChain documents.
    if text:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
            temp_file.write(text.encode("utf-8"))
            temp_file_path = temp_file.name
        loaders = [TextLoader(temp_file_path)]

    if loaders:
        # Load the raw documents from whichever loader applies.
        docs = []
        for loader in loaders:
            docs.extend(loader.load())

        # Split into overlapping chunks sized for retrieval.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=450, chunk_overlap=50
        )
        splits = text_splitter.split_documents(docs)

        # Embed the chunks and build an in-memory FAISS index.
        embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L12-v2")
        index = FAISS.from_documents(splits, embeddings)

        st.success("Index created successfully!")
        return index

    return None
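

# Load Watsonx credentials from the environment (for example a local .env file containing
# WATSONX_API_KEY=... and WATSONX_PROJECT_ID=...).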
load_dotenv()

watsonx_api_key = os.getenv("WATSONX_API_KEY")
watsonx_project_id = os.getenv("WATSONX_PROJECT_ID")

if not watsonx_api_key or not watsonx_project_id:
    st.error("API Key or Project ID is not set. Please set them as environment variables.")
    st.stop()
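
# Prompt in the Llama 3 chat format: the retrieved context goes in the system turn, the
# user's question in the user turn, and the assistant header cues the model's reply.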
prompt_template_br = PromptTemplate(
    input_variables=["context", "question"],
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant. Use the following context to answer the question.
{context}<|eot_id|><|start_header_id|>user<|end_header_id|>
{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
""",
)
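
# Sidebar: model choice and generation parameters.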
with st.sidebar:
    st.title("Multi-Document Retrieval with Watsonx")
    st.sidebar.write("")
    st.sidebar.markdown("Developed by **Abdul Rahman H**")

    watsonx_model = st.selectbox(
        "Model",
        [
            "meta-llama/llama-3-405b-instruct",
            "ibm/granite-34b-code-instruct",
            "google/flan-t5-xxl",
        ],
    )
    max_new_tokens = st.slider(
        "Max output tokens", min_value=100, max_value=4000, value=600, step=100
    )
    decoding_method = st.radio(
        "Decoding", (DecodingMethods.GREEDY.value, DecodingMethods.SAMPLE.value)
    )
    parameters = {
        GenParams.DECODING_METHOD: decoding_method,
        GenParams.MAX_NEW_TOKENS: max_new_tokens,
        GenParams.MIN_NEW_TOKENS: 1,
        GenParams.TEMPERATURE: 0,
        GenParams.TOP_K: 50,
        GenParams.TOP_P: 1,
        GenParams.STOP_SEQUENCES: [],
        GenParams.REPETITION_PENALTY: 1,
    }
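
# File upload: the accepted formats match the branches in load_file.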
st.info("Upload a file to use RAG")
uploaded_file = st.file_uploader(
    "Upload file",
    accept_multiple_files=False,
    type=["pdf", "docx", "txt", "pptx", "csv", "json", "xml", "yaml", "html", "htm"],
)
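
# Persist the upload under its original filename, then build (or fetch the cached) index.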
if uploaded_file is not None:
    bytes_data = uploaded_file.read()
    st.write("Filename:", uploaded_file.name)
    with open(uploaded_file.name, "wb") as f:
        f.write(bytes_data)
    file_type = uploaded_file.name.split(".")[-1].lower()
    index = load_file(uploaded_file.name, file_type)

model_name = watsonx_model
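
# Instantiate the Watsonx foundation model and wrap it for use as a LangChain LLM.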
st.info("Setting up Watsonx...")
my_credentials = {
    "url": "https://us-south.ml.cloud.ibm.com",
    "apikey": watsonx_api_key,
}
params = parameters
project_id = watsonx_project_id
space_id = None
verify = False  # False skips TLS certificate verification for the Watsonx endpoint
model = WatsonxLLM(
    model=Model(model_name, my_credentials, params, project_id, space_id, verify)
)
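
# The plain LLM chain is always available; RetrievalQA is added once a document index exists.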
chain = None
if model:
    st.info(f"Model {model_name} ready.")
    chain = LLMChain(llm=model, prompt=prompt_template_br, verbose=True)

if chain and index is not None:
    rag_chain = RetrievalQA.from_chain_type(
        llm=model,
        chain_type="stuff",
        retriever=index.as_retriever(),
        chain_type_kwargs={"prompt": prompt_template_br},
        return_source_documents=False,
        verbose=True,
    )
    st.info("Document-based retrieval is ready.")
else:
    st.warning("No document uploaded or chain setup issue.")
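
# Chat UI: replay the stored conversation, then accept a new question.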
if "messages" not in st.session_state:
    st.session_state.messages = []

for message in st.session_state.messages:
    st.chat_message(message["role"]).markdown(message["content"])

prompt = st.chat_input("Ask your question here", disabled=chain is None)
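
# Answer with retrieval when an index is available; otherwise fall back to the plain LLM chain.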
if prompt:
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    if rag_chain:
        response_text = rag_chain.run(prompt).strip()
    else:
        response_text = chain.run(question=prompt, context="").strip()

    st.chat_message("assistant").markdown(response_text)
    st.session_state.messages.append({"role": "assistant", "content": response_text})