Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import bm25s | |
| from operator import itemgetter | |
| import os | |
| import re | |
| import pandas as pd | |
| from langchain_groq import ChatGroq | |
| from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate | |
| from langchain.docstore.document import Document | |
def load_data():
    """Load the document corpus from CSV and build a BM25 retriever over it.

    Returns:
        bm25s.BM25: a retriever indexed on the tokenized corpus.
    """
    # The CSV has no header row; its single column holds one document per line.
    df = pd.read_csv("cleaned_list.csv", header=None)
    df.columns = ['document']
    # .tolist() directly — the original identity comprehension added nothing.
    corpus = df['document'].tolist()
    retriever = bm25s.BM25(corpus=corpus)
    retriever.index(bm25s.tokenize(corpus))
    return retriever
| # def extract_hscode(text): | |
| # match = re.search(r'hs_code:\s*(\d+)', text) | |
| # if match: | |
| # return match.group(1) | |
| # return None | |
| # df2 = pd.read_csv("hscode_main.csv") | |
| # new_col = [len(str(code))for code in df2['hs_code'].to_list()] | |
| # df2['len'] = new_col | |
| # new_hscode = [str(code) for code in df2['hs_code']] | |
| # for i in range(len(new_col)): | |
| # if new_col[i]==5: | |
| # new_hscode[i] = '0'+ new_hscode[i] | |
| # df2['hs_code'] = new_hscode | |
| # df2=df2.drop(columns='len') | |
| # if 'retriever' not in st.session_state: | |
| # st.session_state.retriever = None | |
| # if st.session_state.retriever is None: | |
| # st.session_state.retriever = load_data() | |
| # sentence = st.text_input("please enter description:") | |
| # if sentence !='': | |
| # results,_ = st.session_state.retriever.retrieve(bm25s.tokenize(sentence), k=5) | |
| # doc = [d for d in results] | |
| # hscodes = [extract_hscode(item) for item in doc[0]] | |
| # for code in hscodes: | |
| # if len(code)==5: | |
| # code = '0'+ code | |
| # filter_df = df2[df2['hs_code']==code] | |
| # answer = filter_df['description'].iloc[0] | |
| # st.write("Hscode:",code) | |
| # st.write("Description:",answer.lower()) | |
def load_model():
    """Build the HS-code extraction chain: a chat prompt piped into a Groq LLM.

    Returns:
        A LangChain runnable expecting {'context': ..., 'description': ...}
        on invoke; the LLM is instructed to answer with a 6-digit HS code.
    """
    # Plain (non-f) template string: {context} and {description} are template
    # variables filled in by LangChain at invoke time, not by Python — the
    # original f-string prefix only forced pointless {{...}} escaping.
    prompt = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template(
            """
Extract the appropriate 6-digit HS Code base on the product description and retrieved document by thoroughly analyzing its details and utilizing a reliable and up-to-date HS Code database for accurate results.
Only return the HS Code as a 6-digit number .
Example: 123456
Context: {context}
Description: {description}
Answer:
"""
        )
    ])
    # SECURITY: never hard-code API keys in source — the previous revision
    # committed two live Groq keys (both must be revoked). Read the key from
    # the environment (e.g. the Space's secrets) and fail fast if absent.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError("GROQ_API_KEY environment variable is not set")
    llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=api_key)
    return prompt | llm
def process_input(sentence):
    """Tokenize the query, run BM25 retrieval (top 15), and wrap hits as Documents.

    Args:
        sentence: the free-text product description entered by the user.

    Returns:
        list[Document]: the retrieved corpus entries for the single query.
    """
    hits, _ = st.session_state.retriever.retrieve(bm25s.tokenize(sentence), k=15)
    # retrieve() returns one row of results per query; we issued exactly one.
    return [Document(text) for text in hits[0]]
# --- App entry point -------------------------------------------------------
# Build the retriever and LLM chain once per session; Streamlit reruns this
# script on every interaction, so both are cached in session_state.
if st.session_state.get('retriever') is None:
    st.session_state.retriever = load_data()
if st.session_state.get('chain') is None:
    st.session_state.chain = load_model()

sentence = st.text_input("please enter description:")
if sentence:
    documents = process_input(sentence)
    payload = {'context': documents, 'description': sentence}
    hscode = st.session_state.chain.invoke(payload)
    st.write("answer:", hscode.content)