Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import streamlit.components.v1 as components | |
| import os | |
| import html | |
| import re | |
| import torch | |
| from datasets import load_dataset | |
| from langchain_ollama import OllamaLLM | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.chains import RetrievalQA | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.docstore.document import Document | |
| from langchain_openai import ChatOpenAI | |
| from langchain_together import Together | |
| from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate | |
# 🔹 Load and inject custom CSS
def load_css(file_name):
    """Read a stylesheet from disk and inject it into the Streamlit page.

    Args:
        file_name: Path to the CSS file, relative to the working directory.

    Raises:
        FileNotFoundError: If the stylesheet does not exist.
    """
    # Explicit encoding so the stylesheet decodes identically on every
    # platform (the previous code relied on the locale-dependent default).
    with open(file_name, encoding="utf-8") as f:
        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)

load_css("styles.css")
# 🔹 System prompt
# Behavioral instructions prepended verbatim to every prompt the app builds
# (see handle_submit); also fed to the chat_prompt template below.
system_prompt = """
You are Adrega AI, a helpful assistant trained to answer questions based on internal company documentation.
Instructions:
- Be concise and professional and keep response length under 512 tokens if possible.
- Use markdown formatting when helpful.
- Never show the thought process or reasoning steps.
- If the user asks about a specific operation in the product, explain step by step how to do it.
- If the user greets you (e.g. “Hi”, “Hello”), respond with a friendly greeting and offer help. Do not reference documentation unless a question is asked.
- Reference the Help page from the hamburger menu when needed.
- The Field Reference module is not a physical module, but a description of fields in Adrega. Prioritize it when describing fields.
- Do not make assumptions. Use the manual as reference.
"""
# 🔹 Chat prompt template
# NOTE(review): chat_prompt is built here but never referenced by the visible
# submit path, which assembles prompts as plain f-strings in handle_submit —
# confirm whether this template is still needed or is dead code.
chat_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(system_prompt),
    HumanMessagePromptTemplate.from_template("""
Context:
{context}
Question:
{question}
""")
])
# 🔹 Provider selection
# "HuggingFace" is supported by get_llm() below but currently hidden from the UI.
#provider_options = ["Together", "HuggingFace", "Adrega"]
provider_options = ["Adrega", "Together"]
# Changing the selection reruns the whole script; the LLM is rebuilt each run.
selected_provider = st.selectbox("Provider", provider_options)
# 🔹 Load and process dataset once
# Streamlit re-executes this script on every interaction, so the expensive
# dataset download, chunking, embedding, and FAISS index build are guarded to
# run only on the first pass and then cached in session_state.
if "initialized" not in st.session_state:
    dataset = load_dataset("andreska/Adrega61Docs", split="train")

    def read_dataset(dataset):
        # Flatten every record into one text blob; the splitter below re-chunks it.
        return "\n\n".join([
            f"Title: {item['title']}\nModule: {item['module']}\nContent: {item['content']}"
            for item in dataset
        ])

    raw_text = read_dataset(dataset)
    # ~1000-char chunks with 100-char overlap so context straddles boundaries.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    docs = text_splitter.split_documents([Document(page_content=raw_text)])
    # Smart device detection for HuggingFace Spaces
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': device}
    )
    db = FAISS.from_documents(docs, embedding_model)
    # Persist across reruns; the retriever below reads st.session_state.db.
    st.session_state.db = db
    st.session_state.docs = docs
    st.session_state.initialized = True
# 🔹 Provider-specific LLM setup
def get_llm(provider):
    """Return the language-model client backing the selected provider.

    Any unrecognized provider name silently falls back to the local
    "Adrega" Ollama endpoint.

    Args:
        provider: One of "Adrega", "Together", or "HuggingFace".

    Returns:
        A LangChain LLM/chat-model instance for the chosen provider.
    """
    # Sanitize the selection; default to the in-house endpoint when unknown.
    if provider not in ("Adrega", "Together", "HuggingFace"):
        provider = "Adrega"

    if provider == "Together":
        return Together(
            model="deepseek-ai/DeepSeek-R1-Distilled",
            max_tokens=2048,
        )

    if provider == "HuggingFace":
        # Routed through the HF inference router's OpenAI-compatible API.
        return ChatOpenAI(
            base_url="https://router.huggingface.co/v1",
            api_key=os.environ["HF_API_KEY"],
            model="HuggingFaceTB/SmolLM3-3B:hf-inference",
        )

    # Default: self-hosted Ollama instance ("Adrega").
    return OllamaLLM(
        base_url=os.environ["ADREGA_URL"],
        model="llama3:latest",
        streaming=False,
    )
# 🔹 Build the LLM and retriever
llm = get_llm(selected_provider)
# Top-5 similarity search over the FAISS index built during initialization.
retriever = st.session_state.db.as_retriever(search_kwargs={"k": 5})
# 🔹 UI Styling
# Injects page-level CSS plus a MutationObserver that makes the text input
# select-all on click.
# NOTE(review): Streamlit inserts st.markdown HTML via innerHTML, so the
# <script> block below is typically NOT executed even with
# unsafe_allow_html=True — verify it actually runs in the deployed Space.
st.markdown("""<style>
.scrollable-div {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
max-height: 400px;
overflow-y: auto;
padding: 10px;
border: 1px solid #ccc;
background-color: #f9f9f9;
}
input[data-testid="stTextInput"] { cursor: pointer; }
input[data-testid="stTextInput"]:focus { cursor: text; }
</style>
<script>
document.addEventListener('DOMContentLoaded', function() {
const observer = new MutationObserver(function(mutations) {
const input = document.querySelector('input[data-testid="stTextInput"]');
if (input && !input.hasAttribute('data-select-all')) {
input.setAttribute('data-select-all', 'true');
input.addEventListener('click', function() {
this.select();
});
}
});
observer.observe(document.body, { childList: true, subtree: true });
});
</script>""", unsafe_allow_html=True)
# 🔹 Initialize chat history
# Seed the conversation with a greeting so the chat box is never empty on load;
# the guard keeps the history intact across Streamlit reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = [{"role": "assistant", "content": "Hi! I am your Adrega AI assistant. How can I help you today?"}]
# 🔹 Build limited context from history
def build_context(history, max_chars=10000):
    """Assemble a transcript of the most recent chat turns, oldest first.

    Walks the history backwards (newest first) and stops before the entry
    that would push the running character count past max_chars, so the most
    recent exchanges are always kept.

    Args:
        history: List of {"role": ..., "content": ...} message dicts.
        max_chars: Soft cap on the total transcript size in characters.

    Returns:
        The selected turns, each formatted as "Speaker: text\\n", joined
        with blank lines.
    """
    transcript = []
    budget = max_chars
    for message in reversed(history):
        speaker = "You" if message["role"] == "user" else "Adrega AI"
        line = f"{speaker}: {message['content']}\n"
        budget -= len(line)
        if budget < 0:
            break
        transcript.append(line)
    transcript.reverse()
    return "\n".join(transcript)
def _markdown_to_html(text):
    """Convert a small markdown subset from a chat message into inline HTML.

    Handles **bold**, bullet lists (- / *), numbered lists (1. 2. ...) and
    newlines. Fix: plain (non-list) lines are rejoined with <br> — the
    previous code split on '<br>' but reassembled with ''.join, silently
    deleting every line break in ordinary multi-line replies.

    Args:
        text: Raw message content (may contain HTML entities).

    Returns:
        An HTML fragment safe to concatenate into the chat transcript.
    """
    # Decode HTML entities first so markdown markers inside them are visible.
    content = html.unescape(text)
    content = content.replace('\n', '<br>')
    # Alternate every '**' pair into <b>...</b>. (An unmatched trailing '**'
    # becomes a lone <b>, matching the original behavior.)
    while '**' in content:
        content = content.replace('**', '<b>', 1).replace('**', '</b>', 1)

    parts = []
    in_ul = False
    in_ol = False
    for line in content.split('<br>'):
        stripped = line.strip()
        if re.match(r'^\d+\.\s+', stripped):
            # Numbered list item, e.g. "1. Open the menu".
            if in_ul:
                parts.append('</ul>')
                in_ul = False
            if not in_ol:
                parts.append('<ol>')
                in_ol = True
            item_text = re.sub(r'^\d+\.\s+', '', stripped)
            parts.append(f'<li>{item_text}</li>')
        elif stripped.startswith(('- ', '* ')):
            # Bullet list item.
            if in_ol:
                parts.append('</ol>')
                in_ol = False
            if not in_ul:
                parts.append('<ul>')
                in_ul = True
            parts.append(f'<li>{stripped[2:]}</li>')
        else:
            # Plain line: close any open list and KEEP the line break.
            if in_ul:
                parts.append('</ul>')
                in_ul = False
            if in_ol:
                parts.append('</ol>')
                in_ol = False
            parts.append(line + '<br>')
    if in_ul:
        parts.append('</ul>')
    if in_ol:
        parts.append('</ol>')

    result = ''.join(parts)
    # Drop the single trailing break added after the last plain line.
    if result.endswith('<br>'):
        result = result[:-4]
    return result


def render_chat():
    """Render the chat history in a scrollable box that auto-scrolls to the
    newest message, via an embedded components.html iframe."""
    html_content = """
    <style>
        .scrollable-div {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
            max-height: 400px;
            overflow-y: auto;
            padding: 10px;
            border: 1px solid #ccc;
            background-color: #f9f9f9;
        }
    </style>
    <div class="scrollable-div" id="chat-box">
    """
    for msg in st.session_state.chat_history:
        content = _markdown_to_html(msg["content"])
        # NOTE(review): content is inserted unescaped, so HTML in a user
        # message is rendered inside the (sandboxed) iframe — confirm this
        # injection surface is acceptable.
        if msg["role"] == "user":
            html_content += f'<div style="background-color:#e0e0e0; padding:8px; border-radius:5px; margin-bottom:6px;"><b>You:</b> {content}</div>'
        else:
            html_content += f'<div style="padding:8px; margin-bottom:6px;"><b>Adrega AI:</b> {content}</div>'
    # Close the box and auto-scroll it to the bottom.
    html_content += """
    </div>
    <script>
        const chatBox = document.getElementById("chat-box");
        if (chatBox) {
            chatBox.scrollTop = chatBox.scrollHeight;
        }
    </script>
    """
    components.html(html_content, height=420, scrolling=False)
# 🔹 Handle user input
def _strip_reasoning(answer):
    """Strip instruction-echo sentences and chain-of-thought lines from a reply.

    Applied to providers other than "Together", whose models tend to leak
    their reasoning into the answer.

    Args:
        answer: Raw model reply text.

    Returns:
        The cleaned reply, stripped of leading/trailing whitespace.
    """
    # Pass 1: drop sentences that merely restate the prompt's instructions.
    sentences = answer.split('. ')
    clean_sentences = [
        sentence.strip() for sentence in sentences
        if not any(phrase in sentence.lower() for phrase in [
            'keep the response', 'as per the instructions', 'per the instructions',
            'following the instructions', 'according to', 'make sure to',
            'remember to', 'be sure to', 'don\'t forget to'
        ])
    ]
    if clean_sentences:
        answer = '. '.join(clean_sentences)
    # Pass 2: drop whole lines that look like visible reasoning steps.
    clean_lines = [
        line for line in answer.split('\n')
        if not any(phrase in line.lower() for phrase in [
            'okay,', 'let me', 'looking at', 'i need', 'i should', 'wait,',
            'so the response', 'therefore'
        ])
    ]
    return '\n'.join(clean_lines).strip()


def handle_submit():
    """Text-input callback: build a prompt, query the LLM, update chat history.

    Reads module-level globals (retriever, llm, selected_provider,
    system_prompt, build_context) and st.session_state (user_input,
    chat_history). Appends the user/assistant exchange to chat_history.
    """
    user_input = st.session_state.user_input
    if not user_input:
        return
    context = build_context(st.session_state.chat_history)
    # Heuristic: question words, a '?', or a longish message suggest the user
    # needs documentation rather than small talk.
    needs_docs = any(word in user_input.lower() for word in ['how', 'what', 'where', 'when', 'why', '?']) or len(user_input.split()) > 5
    if needs_docs:
        # NOTE(review): get_relevant_documents is deprecated in newer LangChain
        # releases in favor of retriever.invoke(...) — confirm installed version.
        docs = retriever.get_relevant_documents(user_input)
        doc_context = "\n\n".join([doc.page_content for doc in docs[:3]])
        if selected_provider == "Together":
            full_prompt = f"{system_prompt}\n\nConversation History:\n{context}\n\nDocumentation Context:\n{doc_context}\n\nUser: {user_input}\nAdrega AI:"
        else:
            full_prompt = f"{system_prompt}\n\nConversation History:\n{context}\n\nDocumentation Context:\n{doc_context}\n\nUser: {user_input}\n\nRespond directly as Adrega AI without showing your thought process:"
    else:
        # Simple conversational response without documentation retrieval.
        if selected_provider == "Together":
            full_prompt = f"{system_prompt}\n\nConversation History:\n{context}\n\nUser: {user_input}\nAdrega AI:"
        else:
            full_prompt = f"{system_prompt}\n\nConversation History:\n{context}\n\nUser: {user_input}\n\nAnswer directly without any reasoning or explanation of your thought process:"
    answer = llm.invoke(full_prompt)
    # Chat models return a message object with .content; completion models a str.
    if hasattr(answer, 'content'):
        answer = answer.content
    elif isinstance(answer, dict) and 'content' in answer:
        answer = answer['content']
    if selected_provider != "Together" and isinstance(answer, str):
        # (Removed a no-op from the original that stripped a trailing period
        # and immediately re-appended it.)
        answer = _strip_reasoning(answer)
    # Record the exchange.
    # NOTE(review): user_input is not cleared here, so the separate "Ask"
    # button can resubmit the same text on the next rerun — confirm intended.
    st.session_state.chat_history.append({"role": "user", "content": user_input})
    st.session_state.chat_history.append({"role": "assistant", "content": answer})
# 🔹 Create container for chat display
# Created before the input so the transcript renders above the input widget.
chat_container = st.container()
# 🔹 Input field — pressing Enter fires handle_submit via on_change.
st.text_input('Ask me a question', key='user_input', on_change=handle_submit)
# NOTE(review): this button calls handle_submit while user_input is still set,
# so clicking "Ask" after pressing Enter can append the same exchange twice —
# confirm intended behavior.
if st.button("Ask"):
    handle_submit()
# 🔹 Display chat history in container
with chat_container:
    render_chat()