Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import requests | |
| import os | |
| import unicodedata | |
| import resources # Assuming this file exists in your repo | |
| import tracker | |
| import rag_engine # Now safe to import at top level (lazy loading enabled) | |
| from openai import OpenAI | |
| from datetime import datetime | |
# --- CONFIGURATION ---
# NOTE(review): the page_icon glyph looks mis-encoded in this copy of the
# file -- confirm the intended emoji against the original source.
st.set_page_config(
    page_title="Navy AI Toolkit",
    page_icon="β",
    layout="wide",
)

# 1. SETUP CREDENTIALS
# Both settings are read from the environment and are None when unset.
API_URL_ROOT = os.environ.get("API_URL")  # For Ollama models
OPENAI_KEY = os.environ.get("OPENAI_API_KEY")  # For GPT-4o

# --- INITIALIZATION ---
# Seed an empty role list the first time this session runs the script.
if "roles" not in st.session_state:
    st.session_state["roles"] = []
# --- LOGIN / REGISTER LOGIC ---
# A missing or None authentication status means no login attempt has
# succeeded yet, so the login/register tabs are shown.
if st.session_state.get("authentication_status") is None:
    login_tab, register_tab = st.tabs(["π Login", "π Register"])

    with login_tab:
        is_logged_in = tracker.check_login()
        # Download the user's DB ONLY on a fresh login, then rerun so the
        # app itself is rendered instead of the login form.
        if is_logged_in:
            tracker.download_user_db(st.session_state.username)
            st.rerun()

    with register_tab:
        st.header("Create Account")
        with st.form("reg_form"):
            new_user = st.text_input("Username")
            new_name = st.text_input("Display Name")
            new_email = st.text_input("Email")
            new_pwd = st.text_input("Password", type="password")
            invite = st.text_input("Invitation Passcode")

            submitted = st.form_submit_button("Register")
            if submitted:
                success, msg = tracker.register_user(
                    new_email, new_user, new_name, new_pwd, invite
                )
                # Same message either way; only the banner style differs.
                show_result = st.success if success else st.error
                show_result(msg)

# Stop execution if not logged in
if not st.session_state.get("authentication_status"):
    st.stop()
# --- GLOBAL PLACEHOLDERS ---
# Filled in while the sidebar is built; update_sidebar_metrics() reads them.
metric_placeholder = None
admin_metric_placeholder = None

# --- SIDEBAR (CONSOLIDATED) ---
with st.sidebar:
    st.header("π€ User Profile")
    st.write(f"Welcome, **{st.session_state.name}**")

    st.header("π Usage Tracker")
    metric_placeholder = st.empty()

    # Admin Tools
    if "admin" in st.session_state.roles:
        st.divider()
        st.header("π‘οΈ Admin Tools")
        admin_metric_placeholder = st.empty()

        # The usage log lives at the tracker's persistence path.
        log_path = tracker.get_log_path()
        if not log_path.exists():
            st.warning("No logs found yet.")
        else:
            with open(log_path, "r") as f:
                log_data = f.read()
            # The export is named after today's date.
            export_name = f"usage_log_{datetime.now().strftime('%Y-%m-%d')}.json"
            st.download_button(
                label="π₯ Download Usage Logs",
                data=log_data,
                file_name=export_name,
                mime="application/json",
            )

    # Logout
    if "authenticator" in st.session_state:
        st.session_state.authenticator.logout(location='sidebar')

    st.divider()

    # --- MODEL SELECTOR ---
    st.header("π§ Model Selector")
    # Display label -> technical model name sent to the local API.
    model_map = {
        "Granite 4 (IBM)": "granite4:latest",
        "Llama 3.2 (Meta)": "llama3.2:latest",
        "Gemma 3 (Google)": "gemma3:latest",
    }
    model_options = [*model_map]
    model_captions = ["Slower for now, but free and private"] * len(model_options)

    # Only admins are offered the OpenAI-hosted model.
    if "admin" in st.session_state.roles:
        model_options.append("GPT-4o (Omni)")
        model_captions.append("Fast, smart, sends data to OpenAI")

    model_choice = st.radio(
        "Choose your Intelligence:",
        model_options,
        captions=model_captions,
    )
    st.info(f"Connected to: **{model_choice}**")

    st.divider()
    st.header("βοΈ Controls")
    max_len = st.slider("Max Response Length (Tokens)", 100, 2000, 500)
| # --- HELPER FUNCTIONS --- | |
def update_sidebar_metrics():
    """Redraw the sidebar token metrics from the tracker's daily stats.

    Does nothing when the sidebar placeholder has not been created. The
    current user's input and output token counts are summed into one
    metric; the team total is drawn only for admins whose placeholder
    exists.
    """
    if metric_placeholder is None:
        return

    stats = tracker.get_daily_stats()

    # Users with no recorded usage fall back to zero counts.
    no_usage = {"input": 0, "output": 0}
    mine = stats["users"].get(st.session_state.username, no_usage)
    my_total = mine["input"] + mine["output"]
    metric_placeholder.metric("My Tokens Today", my_total)

    is_admin = "admin" in st.session_state.roles
    if is_admin and admin_metric_placeholder is not None:
        admin_metric_placeholder.metric("Team Total Today", stats["total_tokens"])
# Call metrics once on load so the sidebar placeholders are populated
# before any tab triggers a refresh.
update_sidebar_metrics()
def query_local_model(user_prompt, system_persona, max_tokens, model_name):
    """Send a generation request to the self-hosted model API.

    Args:
        user_prompt: Text sent as the request's "text" field.
        system_persona: Text sent as the request's "persona" field.
        max_tokens: Value sent as the request's "max_tokens" field.
        model_name: Technical model name sent as the "model" field.

    Returns:
        A ``(text, usage)`` tuple. On HTTP 200, ``text`` is the reply's
        "response" field (default "") and ``usage`` its "usage" field
        (default zero input/output counts). On any failure, ``text`` is an
        error message and ``usage`` is None.
    """
    if not API_URL_ROOT:
        return "Error: API_URL not set.", None

    # Tolerate a trailing slash in API_URL so the endpoint never becomes
    # ".../" + "/generate" (a double slash).
    url = API_URL_ROOT.rstrip("/") + "/generate"
    payload = {
        "text": user_prompt,
        "persona": system_persona,
        "max_tokens": max_tokens,
        "model": model_name,
    }

    try:
        response = requests.post(url, json=payload, timeout=120)
        if response.status_code != 200:
            return f"Error {response.status_code}: {response.text}", None

        response_data = response.json()
        ans = response_data.get("response", "")
        usage = response_data.get("usage", {"input": 0, "output": 0})
        return ans, usage
    except Exception as e:
        # Deliberately broad: callers display the returned text in the UI,
        # so any failure is reported as a message rather than raised.
        return f"Connection Error: {e}", None
def query_gpt4o(prompt, persona, max_tokens):
    """Ask the "gpt-4o" chat model for a completion.

    Args:
        prompt: Content of the user message.
        persona: Content of the system message.
        max_tokens: Passed through as the completion's ``max_tokens``.

    Returns:
        A ``(text, usage)`` tuple. On success ``text`` is the first
        choice's message content and ``usage`` is a dict with "input"
        (prompt tokens) and "output" (completion tokens). When the key is
        missing or the call raises, ``text`` is an error message and
        ``usage`` is None.
    """
    if not OPENAI_KEY:
        return "Error: OPENAI_API_KEY not set.", None

    client = OpenAI(api_key=OPENAI_KEY)
    conversation = [
        {"role": "system", "content": persona},
        {"role": "user", "content": prompt},
    ]

    try:
        completion = client.chat.completions.create(
            model="gpt-4o",
            max_tokens=max_tokens,
            messages=conversation,
            temperature=0.3,
        )
        token_counts = completion.usage
        usage_dict = {
            "input": token_counts.prompt_tokens,
            "output": token_counts.completion_tokens,
        }
        return completion.choices[0].message.content, usage_dict
    except Exception as e:
        return f"OpenAI Error: {e}", None
def clean_text(text):
    """Normalize text and replace typographic punctuation with ASCII.

    The input is NFKC-normalized, curly quotes, en/em dashes, the ellipsis
    character and non-breaking spaces are mapped to plain ASCII, and
    surrounding whitespace is stripped.

    Args:
        text: The string to clean; any falsy value (None, "") is accepted.

    Returns:
        The cleaned string, or "" when ``text`` is falsy.
    """
    if not text:
        return ""

    # Explicit escapes keep the mapping intact however this file is
    # encoded or re-saved; literal glyphs here were previously corrupted
    # into duplicate keys, which broke the substitutions.
    ascii_equivalents = str.maketrans({
        "\u201c": '"',    # left double quotation mark
        "\u201d": '"',    # right double quotation mark
        "\u2018": "'",    # left single quotation mark
        "\u2019": "'",    # right single quotation mark
        "\u2013": "-",    # en dash
        "\u2014": "-",    # em dash
        "\u2026": "...",  # horizontal ellipsis
        "\u00a0": " ",    # non-breaking space
    })

    normalized = unicodedata.normalize("NFKC", text)
    return normalized.translate(ascii_equivalents).strip()
def ask_ai(user_prompt, system_persona, max_tokens):
    """Route a prompt to the backend matching the sidebar model choice.

    A selection containing "GPT-4o" goes to query_gpt4o; any other
    selection is translated through model_map and sent to
    query_local_model. Returns whatever the chosen backend returns.
    """
    if "GPT-4o" not in model_choice:
        technical_name = model_map[model_choice]
        return query_local_model(user_prompt, system_persona, max_tokens, technical_name)

    return query_gpt4o(user_prompt, system_persona, max_tokens)
# --- MAIN UI ---
# Page title followed by the four top-level tabs; each tab's content is
# rendered in its own `with tabN:` section below.
st.title("AI Toolkit")
tab1, tab2, tab3, tab4 = st.tabs(["π§ Email Builder", "π¬ Chat Playground", "π οΈ Prompt Architect", "π Knowledge Base"])
# --- TAB 1: EMAIL BUILDER ---
# Collects a writing style, recipient/topic and notes, then asks the
# selected model for a draft that is kept in session state across reruns.
with tab1:
    st.header("Structured Email Generator")
    if "email_draft" not in st.session_state:
        st.session_state.email_draft = ""

    st.subheader("1. Define the Voice")
    style_mode = st.radio(
        "How should the AI write?",
        ["Use a Preset Persona", "Mimic My Style"],
        horizontal=True,
    )

    # Stays empty when "Mimic My Style" is chosen but nothing is uploaded.
    selected_persona_instruction = ""
    if style_mode == "Use a Preset Persona":
        persona_name = st.selectbox("Select a Persona", list(resources.TONE_LIBRARY.keys()))
        selected_persona_instruction = resources.TONE_LIBRARY[persona_name]
        st.info(f"**System Instruction:** {selected_persona_instruction}")
    else:
        st.info("Upload 1-3 text files of your previous emails.")
        uploaded_style_files = st.file_uploader(
            "Upload Samples (.txt)", type=["txt"], accept_multiple_files=True
        )
        if uploaded_style_files:
            # Each sample is fenced with "---" lines inside the instruction.
            style_context = ""
            for uploaded_file in uploaded_style_files:
                string_data = uploaded_file.read().decode("utf-8")
                style_context += f"---\n{string_data}\n---\n"
            selected_persona_instruction = f"Analyze these examples and mimic the style:\n{style_context}"

    st.divider()
    st.subheader("2. Details")
    c1, c2 = st.columns(2)
    with c1:
        recipient = st.text_input("Recipient")
    with c2:
        topic = st.text_input("Topic")

    st.caption("Content Source")
    input_method = st.toggle("Upload notes file?")
    raw_notes = ""
    if input_method:
        notes_file = st.file_uploader("Upload Notes (.txt)", type=["txt"])
        if notes_file:
            raw_notes = notes_file.read().decode("utf-8")
    else:
        raw_notes = st.text_area("Paste notes:", height=150)

    # Context Bar: rough estimate of four characters per token, shown
    # against a 128000-token budget.
    est_tokens = len(raw_notes) / 4
    st.progress(min(est_tokens / 128000, 1.0), text=f"Context: {int(est_tokens)} tokens")

    if st.button("Draft Email", type="primary"):
        if not raw_notes:
            st.warning("Please provide notes.")
        else:
            clean_notes = clean_text(raw_notes)
            with st.spinner(f"Drafting with {model_choice}..."):
                prompt = f"TASK: Write email.\nTO: {recipient}\nTOPIC: {topic}\nSTYLE: {selected_persona_instruction}\nDATA: {clean_notes}"
                reply, usage = ask_ai(prompt, "You are an expert ghostwriter.", max_len)
                st.session_state.email_draft = reply

                if usage:
                    # Attribute usage to the model that actually ran.
                    # Previously every non-Granite choice (Llama, Gemma)
                    # was logged as "GPT-4o". The existing "Granite" and
                    # "GPT-4o" labels are kept; other models use their
                    # display name without the vendor suffix.
                    # NOTE(review): confirm tracker.log_usage accepts
                    # these additional labels.
                    if "GPT-4o" in model_choice:
                        m_name = "GPT-4o"
                    elif "Granite" in model_choice:
                        m_name = "Granite"
                    else:
                        m_name = model_choice.split(" (")[0]
                    tracker.log_usage(m_name, usage["input"], usage["output"])
                    update_sidebar_metrics()  # Force update

    if st.session_state.email_draft:
        st.subheader("Draft Result")
        st.text_area("Copy your email:", value=st.session_state.email_draft, height=300)
# --- TAB 2: CHAT PLAYGROUND ---
# Single-question chat; optionally grounds the answer in documents
# retrieved from the knowledge base before calling the selected model.
with tab2:
    st.header("Choose Your Model and Start a Discussion")
    if "chat_response" not in st.session_state:
        st.session_state.chat_response = ""

    user_input = st.text_input("Ask a question:")

    c1, c2 = st.columns([1, 1])
    with c1:
        use_rag = st.toggle("π Enable Knowledge Base", value=True)
    with c2:
        # Rough estimate of four characters per token, shown against 2000.
        est_tokens = len(user_input) / 4
        st.progress(min(est_tokens / 2000, 1.0), text=f"Input: {int(est_tokens)} tokens")

    if st.button("Send Query"):
        if not user_input:
            st.warning("Please enter a question.")
        else:
            final_prompt = user_input
            system_persona = "You are a helpful assistant."

            # --- RAG LOGIC ---
            if use_rag:
                with st.spinner("π§ Searching Knowledge Base..."):
                    # 1. Retrieve the top matches for this user
                    retrieved_docs = rag_engine.search_knowledge_base(
                        user_input,
                        st.session_state.username,
                        k=3,
                    )

                    if retrieved_docs:
                        # 2. Format Context: one section per document with
                        # its file name and relevance score, if present.
                        context_parts = []
                        for doc in retrieved_docs:
                            score = doc.metadata.get('relevance_score', 'N/A')
                            src = os.path.basename(doc.metadata.get('source', 'Unknown'))
                            context_parts.append(
                                f"---\nSOURCE: {src} (Rel: {score})\nTEXT: {doc.page_content}\n"
                            )
                        context_text = "".join(context_parts)

                        # 3. Update Prompt: the persona now carries the
                        # retrieved context and restricts answers to it.
                        system_persona = (
                            "You are a Navy Document Analyst. "
                            "Answer the user's question strictly based on the Context provided below. "
                            "If the answer is not in the Context, state 'I cannot find that information in the provided documents.' \n\n"
                            f"### CONTEXT:\n{context_text}"
                        )

                        st.success(f"Found {len(retrieved_docs)} relevant documents.")
                        with st.expander("View Context Used"):
                            st.text(context_text)
                    else:
                        st.warning("No relevant documents found. Using general knowledge.")

            # --- GENERATION ---
            with st.spinner(f"Thinking with {model_choice}..."):
                reply, usage = ask_ai(final_prompt, system_persona, max_len)
                st.session_state.chat_response = reply

                if usage:
                    # Attribute usage to the model that actually ran.
                    # Previously every non-Granite choice (Llama, Gemma)
                    # was logged as "GPT-4o". The existing "Granite" and
                    # "GPT-4o" labels are kept; other models use their
                    # display name without the vendor suffix.
                    # NOTE(review): confirm tracker.log_usage accepts
                    # these additional labels.
                    if "GPT-4o" in model_choice:
                        m_name = "GPT-4o"
                    elif "Granite" in model_choice:
                        m_name = "Granite"
                    else:
                        m_name = model_choice.split(" (")[0]
                    tracker.log_usage(m_name, usage["input"], usage["output"])
                    update_sidebar_metrics()

    if st.session_state.chat_response:
        st.divider()
        st.markdown("**AI Response:**")
        st.write(st.session_state.chat_response)
# --- TAB 3: PROMPT ARCHITECT ---
# Assembles a reusable prompt template from three free-text sections and a
# named placeholder; no model is called in this tab.
with tab3:
    st.header("π οΈ Mega-Prompt Factory")
    st.info("Build standard templates for NIPRGPT.")

    c1, c2 = st.columns([1, 1])

    with c1:
        st.subheader("1. Parameters")
        p = st.text_area("Persona", placeholder="Act as...", height=100)
        c = st.text_area("Context", placeholder="Background...", height=100)
        t = st.text_area("Task", placeholder="Action...", height=100)
        v = st.text_input("Placeholder Name", value="PASTE_DATA_HERE")

    with c2:
        st.subheader("2. Result")
        # The placeholder name sits in brackets between triple-quote fences.
        final = f"### ROLE\n{p}\n### CONTEXT\n{c}\n### TASK\n{t}\n### INPUT DATA\n\"\"\"\n[{v}]\n\"\"\""
        st.code(final, language="markdown")
        st.download_button(
            label="πΎ Download .txt",
            data=final,
            file_name="template.txt",
        )
# --- TAB 4: KNOWLEDGE BASE ---
# Upload, search and manage the documents behind retrieval. Uploading,
# deleting and resetting are offered to admins only.
with tab4:
    st.header("π§ Unit Knowledge Base")
    is_admin = "admin" in st.session_state.roles
    kb_tab1, kb_tab2 = st.tabs(["π€ Add Documents", "ποΈ Manage Database"])

    # --- SUB-TAB 1: UPLOAD ---
    with kb_tab1:
        if not is_admin:
            st.info("π Only Admins can upload documents.")
        else:
            st.subheader("Ingest New Knowledge")
            uploaded_file = st.file_uploader(
                "Upload Instructions, Manuals, or Logs",
                type=["pdf", "docx", "txt", "md"],
            )

            col1, col2 = st.columns([1, 2])
            with col1:
                chunk_strategy = st.selectbox(
                    "Chunking Strategy",
                    ["paragraph", "token", "page"],
                    help="Paragraph: Manuals. Token: Dense text. Page: Forms.",
                )

            if uploaded_file and st.button("Process & Add"):
                with st.spinner("Analyzing and Indexing..."):
                    # Save the upload first, then index the saved copy.
                    temp_path = rag_engine.save_uploaded_file(uploaded_file)
                    success, msg = rag_engine.process_and_add_document(
                        temp_path,
                        st.session_state.username,
                        chunk_strategy,
                    )
                    if not success:
                        st.error(f"Failed: {msg}")
                    else:
                        st.success(msg)
                        st.rerun()

        st.divider()
        st.subheader("π Quick Test")
        test_query = st.text_input("Ask the brain something...")
        if test_query:
            results = rag_engine.search_knowledge_base(test_query, st.session_state.username)
            for rank, doc in enumerate(results, start=1):
                # Show only the file name, not the stored path.
                src_name = os.path.basename(doc.metadata.get('source', '?'))
                score = doc.metadata.get('relevance_score', 'N/A')
                with st.expander(f"Match {rank}: {src_name} (Score: {score})"):
                    st.write(doc.page_content)

    # --- SUB-TAB 2: MANAGE ---
    with kb_tab2:
        st.subheader("ποΈ Database Inventory")

        # 1. Fetch current docs
        docs = rag_engine.list_documents(st.session_state.username)

        if docs:
            st.markdown(f"**Total Documents:** {len(docs)}")
            for doc in docs:
                c1, c2, c3 = st.columns([3, 1, 1])
                with c1:
                    st.text(f"π {doc['filename']}")
                with c2:
                    st.caption(f"{doc['chunks']} chunks")
                with c3:
                    if not is_admin:
                        st.caption("Read Only")
                    # The document's source doubles as the widget key.
                    elif st.button("ποΈ Delete", key=doc['source']):
                        with st.spinner("Deleting..."):
                            success, msg = rag_engine.delete_document(
                                st.session_state.username, doc['source']
                            )
                            if not success:
                                st.error(msg)
                            else:
                                st.success(msg)
                                st.rerun()
        else:
            st.info("Knowledge Base is empty.")

        if is_admin and docs:
            st.divider()
            with st.expander("π¨ Danger Zone"):
                if st.button("β’οΈ RESET ENTIRE DATABASE", type="primary"):
                    success, msg = rag_engine.reset_knowledge_base(st.session_state.username)
                    # NOTE(review): no failure branch is visible for a
                    # failed reset -- confirm whether one follows here.
                    if success:
                        st.success(msg)
                        st.rerun()