# Spaces:
# Sleeping
# Sleeping
| import streamlit as st | |
| import requests | |
| import os | |
| import unicodedata | |
| import resources | |
| import tracker | |
| import rag_engine | |
| from openai import OpenAI | |
| from datetime import datetime | |
# --- CONFIGURATION ---
st.set_page_config(page_title="Navy AI Toolkit", page_icon="β", layout="wide")

# 1. SETUP CREDENTIALS
API_URL_ROOT = os.getenv("API_URL")  # For Ollama models
OPENAI_KEY = os.getenv("OPENAI_API_KEY")  # For GPT-4o

# --- INITIALIZATION ---
# Ensure the roles list exists before any `"admin" in st.session_state.roles`
# checks later in the script.
if "roles" not in st.session_state:
    st.session_state.roles = []

# --- LOGIN / REGISTER LOGIC ---
# NOTE(review): "authentication_status" appears to be populated by
# tracker.check_login() (streamlit-authenticator convention) -- confirm
# against the tracker module.
if "authentication_status" not in st.session_state or st.session_state["authentication_status"] is None:
    # If not logged in, show tabs
    login_tab, register_tab = st.tabs(["π Login", "π Register"])
    with login_tab:
        is_logged_in = tracker.check_login()
        if is_logged_in:
            # Check if a different user was previously logged in
            if "last_user" in st.session_state and st.session_state.last_user != st.session_state.username:
                # WIPE EVERYTHING
                st.session_state.messages = []
                st.session_state.email_draft = ""
                st.session_state.user_openai_key = None
            # Update the tracker
            st.session_state.last_user = st.session_state.username
            # Download DB and Refresh
            tracker.download_user_db(st.session_state.username)
            st.rerun()  # Refresh to show the app
    with register_tab:
        st.header("Create Account")
        with st.form("reg_form"):
            new_user = st.text_input("Username")
            new_name = st.text_input("Display Name")
            new_email = st.text_input("Email")
            new_pwd = st.text_input("Password", type="password")
            invite = st.text_input("Invitation Passcode")
            if st.form_submit_button("Register"):
                success, msg = tracker.register_user(new_email, new_user, new_name, new_pwd, invite)
                if success:
                    st.success(msg)
                else:
                    st.error(msg)

# Stop execution if not logged in
if not st.session_state.get("authentication_status"):
    st.stop()
# --- GLOBAL PLACEHOLDERS ---
# Filled inside the sidebar below; update_sidebar_metrics() writes into them.
metric_placeholder = None
admin_metric_placeholder = None

# --- SIDEBAR (CONSOLIDATED) ---
with st.sidebar:
    st.header("π€ User Profile")
    st.write(f"Welcome, **{st.session_state.name}**")
    st.header("π Usage Tracker")
    metric_placeholder = st.empty()
    # Admin Tools
    if "admin" in st.session_state.roles:
        st.divider()
        st.header("π‘οΈ Admin Tools")
        admin_metric_placeholder = st.empty()
        log_path = tracker.get_log_path()
        if log_path.exists():
            with open(log_path, "r") as f:
                log_data = f.read()
            st.download_button(
                label="π₯ Download Usage Logs",
                data=log_data,
                file_name=f"usage_log_{datetime.now().strftime('%Y-%m-%d')}.json",
                mime="application/json"
            )
        else:
            st.warning("No logs found yet.")
    # Logout
    if "authenticator" in st.session_state:
        st.session_state.authenticator.logout(location='sidebar')
    st.divider()
    # --- MODEL SELECTOR ---
    st.header("π§ Model Selector")
    # Display name -> Ollama model tag served by the local relay.
    model_map = {
        "Granite 4 (IBM)": "granite4:latest",
        "Llama 3.2 (Meta)": "llama3.2:latest",
        "Gemma 3 (Google)": "gemma3:latest"
    }
    model_options = list(model_map.keys())
    model_captions = ["Slower for now, but free and private" for _ in model_options]
    # 2. CHECK FOR GPT-4o ACCESS (Admin OR User Key)
    # We moved the input UP so the user can unlock the option immediately
    # Check if user is admin
    is_admin = "admin" in st.session_state.roles
    # Input for Non-Admins
    user_api_key = None
    if not is_admin:
        # Key is namespaced per username so a new login gets a fresh field.
        user_api_key = st.text_input(
            "π Unlock GPT-4o (Enter API Key)",
            type="password",
            help="Enter your OpenAI API key to access GPT-4o. Press Enter to apply.",
            key=f"user_key_{st.session_state.username}"
        )
        if user_api_key:
            st.session_state.user_openai_key = user_api_key
            st.caption("β Key Active")
        else:
            st.session_state.user_openai_key = None
    else:
        # Admins fall through to the system OPENAI_KEY in query_openai_model().
        st.session_state.user_openai_key = None
    # 3. DYNAMICALLY ADD GPT-4o TO THE LIST
    # If Admin OR if they just entered a key, show the option
    if is_admin or st.session_state.get("user_openai_key"):
        model_options.append("GPT-4o (Omni)")
        model_captions.append("Fast, smart, sends data to OpenAI")
    # 4. RENDER THE SELECTOR
    model_choice = st.radio(
        "Choose your Intelligence:",
        model_options,
        captions=model_captions,
        key="model_selector_radio"
    )
    st.info(f"Connected to: **{model_choice}**")
    st.divider()
    st.header("βοΈ Controls")
    max_len = st.slider("Max Response Length (Tokens)", 100, 2000, 500)
# --- HELPER FUNCTIONS ---
def update_sidebar_metrics():
    """Re-render the sidebar token counters.

    Writes the current user's daily token total into ``metric_placeholder``
    and, for admins, the team-wide total into ``admin_metric_placeholder``.
    No-op when the sidebar placeholders were never created.
    """
    if metric_placeholder is None:
        return
    stats = tracker.get_daily_stats()
    username = st.session_state.username
    mine = stats["users"].get(username, {"input": 0, "output": 0})
    metric_placeholder.metric("My Tokens Today", mine["input"] + mine["output"])
    # Team total is only shown to admins, and only if its placeholder exists.
    if admin_metric_placeholder is not None and "admin" in st.session_state.roles:
        admin_metric_placeholder.metric("Team Total Today", stats["total_tokens"])

# Call metrics once on load
update_sidebar_metrics()
def query_local_model(messages, max_tokens, model_name):
    """Send a chat transcript to the local Ollama relay.

    Flattens the OpenAI-style message list into a single prompt string
    (the last ``system`` message becomes the persona) and POSTs it to
    ``API_URL_ROOT/generate``. Returns ``(text, usage_dict)`` on success
    or ``(error_string, None)`` on any failure.
    """
    if not API_URL_ROOT:
        return "Error: API_URL not set.", None
    url = API_URL_ROOT + "/generate"

    # --- FLATTEN MESSAGE HISTORY ---
    system_persona = "You are a helpful assistant."  # Default
    lines = []
    for msg in messages:
        role = msg['role']
        if role == 'system':
            system_persona = msg['content']
        elif role == 'user':
            lines.append(f"User: {msg['content']}\n")
        elif role == 'assistant':
            lines.append(f"Assistant: {msg['content']}\n")
    # Append the "Assistant:" prompt at the end to cue the model
    formatted_history = "".join(lines) + "Assistant: "

    payload = {
        "text": formatted_history,
        "persona": system_persona,
        "max_tokens": max_tokens,
        "model": model_name,
    }
    try:
        response = requests.post(url, json=payload, timeout=300)
        if response.status_code != 200:
            return f"Error {response.status_code}: {response.text}", None
        body = response.json()
        return body.get("response", ""), body.get("usage", {"input": 0, "output": 0})
    except Exception as e:
        return f"Connection Error: {e}", None
def query_openai_model(messages, max_tokens):
    """Call GPT-4o and return ``(text, usage_dict)`` or ``(error_string, None)``.

    Key resolution order: the user's sidebar-supplied key first, then the
    system OPENAI_KEY; if neither exists an error string is returned.
    """
    api_key_to_use = st.session_state.get("user_openai_key") or OPENAI_KEY
    if not api_key_to_use:
        return "Error: No API Key available. Please enter one in the sidebar.", None

    client = OpenAI(api_key=api_key_to_use)
    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            max_tokens=max_tokens,
            messages=messages,
            temperature=0.3
        )
        usage = response.usage
        tokens = {"input": usage.prompt_tokens, "output": usage.completion_tokens}
        return response.choices[0].message.content, tokens
    except Exception as e:
        return f"OpenAI Error: {e}", None
def clean_text(text):
    """Normalize pasted text to plain-ASCII-friendly form.

    Applies Unicode NFKC normalization, then maps common "smart" punctuation
    (curly quotes, en/em dashes, ellipsis, non-breaking space) to ASCII
    equivalents, and strips surrounding whitespace. Returns "" for
    None/empty input.
    """
    if not text:
        return ""
    text = unicodedata.normalize('NFKC', text)
    # BUG FIX: the original mapping was destroyed by an encoding round-trip:
    # every smart-punctuation key had collapsed into the same mojibake
    # character, so the duplicate dict keys silently overwrote each other and
    # most replacements never ran. Restored with explicit Unicode escapes.
    replacements = {
        '\u201c': '"',    # left double quote
        '\u201d': '"',    # right double quote
        '\u2018': "'",    # left single quote
        '\u2019': "'",    # right single quote
        '\u2013': '-',    # en dash
        '\u2014': '-',    # em dash
        '\u2026': '...',  # ellipsis (NFKC already expands it; kept for safety)
        '\u00a0': ' ',    # non-breaking space (NFKC already maps it)
    }
    for old, new in replacements.items():
        text = text.replace(old, new)
    return text.strip()
def ask_ai(user_prompt, system_persona, max_tokens):
    """Route a single-turn prompt to whichever model the sidebar selected.

    Wraps the prompt/persona strings in the standard chat-message list,
    then dispatches to OpenAI for GPT-4o or to the local relay otherwise.
    Returns the ``(text, usage)`` tuple of the underlying query function.
    """
    # 1. Standardize Input: Convert the strings into the Message List format
    messages_payload = [
        {"role": "system", "content": system_persona},
        {"role": "user", "content": user_prompt},
    ]
    # 2. Routing Logic
    if "GPT-4o" not in model_choice:
        # Local model: translate the display name into its Ollama tag.
        return query_local_model(messages_payload, max_tokens, model_map[model_choice])
    return query_openai_model(messages_payload, max_tokens)
# --- MAIN UI ---
st.title("AI Toolkit")
tab1, tab2, tab3, tab4 = st.tabs(["π§ Email Builder", "π¬ Chat Playground", "π οΈ Prompt Architect", "π Knowledge Base"])

# --- TAB 1: EMAIL BUILDER --- 
with tab1:
    st.header("Structured Email Generator")
    # Draft persists across reruns so the result survives widget interaction.
    if "email_draft" not in st.session_state:
        st.session_state.email_draft = ""
    st.subheader("1. Define the Voice")
    style_mode = st.radio("How should the AI write?", ["Use a Preset Persona", "Mimic My Style"], horizontal=True)
    selected_persona_instruction = ""
    if style_mode == "Use a Preset Persona":
        persona_name = st.selectbox("Select a Persona", list(resources.TONE_LIBRARY.keys()))
        selected_persona_instruction = resources.TONE_LIBRARY[persona_name]
        st.info(f"**System Instruction:** {selected_persona_instruction}")
    else:
        # "Mimic My Style": concatenate uploaded samples into the instruction.
        st.info("Upload 1-3 text files of your previous emails.")
        uploaded_style_files = st.file_uploader("Upload Samples (.txt)", type=["txt"], accept_multiple_files=True)
        if uploaded_style_files:
            style_context = ""
            for uploaded_file in uploaded_style_files:
                string_data = uploaded_file.read().decode("utf-8")
                style_context += f"---\n{string_data}\n---\n"
            selected_persona_instruction = f"Analyze these examples and mimic the style:\n{style_context}"
    st.divider()
    st.subheader("2. Details")
    c1, c2 = st.columns(2)
    with c1: recipient = st.text_input("Recipient")
    with c2: topic = st.text_input("Topic")
    st.caption("Content Source")
    input_method = st.toggle("Upload notes file?")
    raw_notes = ""
    if input_method:
        notes_file = st.file_uploader("Upload Notes (.txt)", type=["txt"])
        if notes_file: raw_notes = notes_file.read().decode("utf-8")
    else:
        raw_notes = st.text_area("Paste notes:", height=150)
    # Context Bar
    # Rough heuristic: ~4 characters per token, against a 128k-token window.
    est_tokens = len(raw_notes) / 4
    st.progress(min(est_tokens / 128000, 1.0), text=f"Context: {int(est_tokens)} tokens")
    if st.button("Draft Email", type="primary"):
        if not raw_notes:
            st.warning("Please provide notes.")
        else:
            clean_notes = clean_text(raw_notes)
            with st.spinner(f"Drafting with {model_choice}..."):
                prompt = f"TASK: Write email.\nTO: {recipient}\nTOPIC: {topic}\nSTYLE: {selected_persona_instruction}\nDATA: {clean_notes}"
                reply, usage = ask_ai(prompt, "You are an expert ghostwriter.", max_len)
                st.session_state.email_draft = reply
                if usage:
                    # Log under a short model label for the usage tracker.
                    if "GPT-4o" in model_choice:
                        m_name = "GPT-4o"
                    else:
                        m_name = model_choice.split(" ")[0]
                    tracker.log_usage(m_name, usage["input"], usage["output"])
                    update_sidebar_metrics()
    if st.session_state.email_draft:
        st.subheader("Draft Result")
        st.text_area("Copy your email:", value=st.session_state.email_draft, height=300)
# --- TAB 2: CHAT PLAYGROUND ---
with tab2:
    st.header("Choose Your Model and Start a Discussion")
    # --- INITIALIZE CHAT MEMORY (MUST BE DONE FIRST) ---
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # --- CONTROLS AND METRICS ---
    c1, c2, c3 = st.columns([2, 1, 1])
    with c1:
        # FIX: Access the correct key from the sidebar widget
        # We default to the global variable 'model_choice' if state is missing
        selected_model_name = st.session_state.get('model_selector_radio', model_choice)
        st.caption(f"Active Model: **{selected_model_name}**")
    with c2:
        use_rag = st.toggle("π Enable Knowledge Base", value=False)
    with c3:
        # --- NEW FEATURE: DOWNLOAD CHAT ---
        # Serialize history as "[ROLE]: content" lines for the download button.
        chat_log = ""
        for msg in st.session_state.messages:
            role = "USER" if msg['role'] == 'user' else "ASSISTANT"
            chat_log += f"[{role}]: {msg['content']}\n\n"
        if chat_log:
            st.download_button(
                label="πΎ Save Chat",
                data=chat_log,
                file_name="mission_log.txt",
                mime="text/plain",
                help="Download the current conversation history."
            )
    st.divider()
    # --- DISPLAY CONVERSATION HISTORY ---
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    # --- CHAT INPUT HANDLING ---
    if prompt := st.chat_input("Ask a question..."):
        # 1. Display User Message and save to history
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        # 2. Default Configuration (Standard AI Mode)
        system_persona = "You are a helpful AI assistant. Answer the user's question to the best of your ability."
        final_user_content = prompt
        retrieved_docs = []
        # 3. Handle RAG Logic (Only if enabled)
        if use_rag:
            with st.spinner("π§ Searching Knowledge Base..."):
                retrieved_docs = rag_engine.search_knowledge_base(
                    prompt,
                    st.session_state.username
                )
            if retrieved_docs:
                # RAG SUCCESS: Switch to Strict Navy Persona
                system_persona = (
                    "You are a Navy Document Analyst. Your task is to answer the user's question "
                    "using ONLY the Context provided below. "
                    "If the answer is not present in the Context, return ONLY this exact phrase: "
                    "'I cannot find that information in the provided documents.'"
                )
                # Format Context
                context_text = ""
                for doc in retrieved_docs:
                    score = doc.metadata.get('relevance_score', 'N/A')
                    src = os.path.basename(doc.metadata.get('source', 'Unknown'))
                    context_text += f"---\nSOURCE: {src} (Rel: {score})\nTEXT: {doc.page_content}\n"
                # Augment User Prompt
                final_user_content = (
                    f"User Question: {prompt}\n\n"
                    f"Relevant Context:\n{context_text}\n\n"
                    "Answer the question using the context provided."
                )
        # 4. Construct Payload (Now using the CORRECT persona)
        messages_payload = [{"role": "system", "content": system_persona}]
        # --- MEMORY LOGIC: SLIDING WINDOW ---
        # Keep the last `history_depth` prior turns; the slice excludes the
        # just-appended user message, which is re-added (possibly augmented).
        history_depth = 8
        recent_history = st.session_state.messages[-(history_depth+1):-1]
        messages_payload.extend(recent_history)
        # Add the final (potentially augmented) user message to payload
        messages_payload.append({"role": "user", "content": final_user_content})
        # 5. Generate Response
        with st.chat_message("assistant"):
            with st.spinner(f"Thinking with {selected_model_name}..."):
                # Determine model ID
                # NOTE(review): this duplicates `model_map` defined in the
                # sidebar -- consider reusing it to avoid drift.
                model_id = ""
                ollama_map = {
                    "Granite 4 (IBM)": "granite4:latest",
                    "Llama 3.2 (Meta)": "llama3.2:latest",
                    "Gemma 3 (Google)": "gemma3:latest"
                }
                for key, val in ollama_map.items():
                    if key in selected_model_name:
                        model_id = val
                        break
                # ROUTING CHECK
                if not model_id and "gpt" in selected_model_name.lower():
                    # If it's the GPT model choice
                    response, usage = query_openai_model(messages_payload, max_len)
                elif model_id:
                    # If it's the local Ollama model
                    response, usage = query_local_model(messages_payload, max_len, model_id)
                else:
                    response, usage = "Error: Could not determine model to use.", None
                st.markdown(response)
        # 6. Save Assistant Response
        st.session_state.messages.append({"role": "assistant", "content": response})
        # 7. Metrics & Context Display
        if usage:
            if "GPT-4o" in selected_model_name:
                m_name = "GPT-4o"
            else:
                m_name = selected_model_name.split(" ")[0]
            tracker.log_usage(m_name, usage["input"], usage["output"])
            update_sidebar_metrics()
        if use_rag and retrieved_docs:
            with st.expander("π View Context Used"):
                for i, doc in enumerate(retrieved_docs):
                    score = doc.metadata.get('relevance_score', 'N/A')
                    src = os.path.basename(doc.metadata.get('source', 'Unknown'))
                    st.caption(f"Rank {i+1} (Source: {src}, Rel: {score})")
                    st.text(doc.page_content)
    st.divider()
# --- TAB 3: PROMPT ARCHITECT ---
with tab3:
    st.header("π οΈ Mega-Prompt Factory")
    st.info("Build standard templates for NIPRGPT.")
    c1, c2 = st.columns([1,1])
    with c1:
        st.subheader("1. Parameters")
        # Persona / Context / Task inputs plus the name of the data slot.
        p = st.text_area("Persona", placeholder="Act as...", height=100)
        c = st.text_area("Context", placeholder="Background...", height=100)
        t = st.text_area("Task", placeholder="Action...", height=100)
        v = st.text_input("Placeholder Name", value="PASTE_DATA_HERE")
    with c2:
        st.subheader("2. Result")
        # Assemble the mega-prompt template; [v] marks where the user pastes data.
        final = f"### ROLE\n{p}\n### CONTEXT\n{c}\n### TASK\n{t}\n### INPUT DATA\n\"\"\"\n[{v}]\n\"\"\""
        st.code(final, language="markdown")
        st.download_button("πΎ Download .txt", final, "template.txt")
# --- TAB 4: KNOWLEDGE BASE ---
with tab4:
    st.header("π§ Personal Knowledge Base")
    st.info(f"Managing knowledge for: **{st.session_state.username}**")
    # We no longer check 'is_admin' for the whole tab
    kb_tab1, kb_tab2 = st.tabs(["π€ Add Documents", "ποΈ Manage Database"])
    # --- SUB-TAB 1: UPLOAD (Unlocked for Everyone) ---
    with kb_tab1:
        st.subheader("Ingest New Knowledge")
        uploaded_file = st.file_uploader("Upload Instructions, Manuals, or Logs", type=["pdf", "docx", "txt", "md"])
        col1, col2 = st.columns([1, 2])
        with col1:
            chunk_strategy = st.selectbox(
                "Chunking Strategy",
                ["paragraph", "token", "page"],
                help="Paragraph: Manuals. Token: Dense text. Page: Forms."
            )
        if uploaded_file and st.button("Process & Add"):
            with st.spinner("Analyzing and Indexing..."):
                # 1. Save temp file
                temp_path = rag_engine.save_uploaded_file(uploaded_file)
                # 2. Process locally
                success, msg = rag_engine.process_and_add_document(
                    temp_path,
                    st.session_state.username,
                    chunk_strategy
                )
                if success:
                    # 3. FIX: SYNC TO CLOUD IMMEDIATELY
                    with st.spinner("Backing up to Cloud..."):
                        tracker.upload_user_db(st.session_state.username)
                    st.success(msg)
                    st.rerun()
                else:
                    st.error(f"Failed: {msg}")
        st.divider()
        st.subheader("π Quick Test")
        test_query = st.text_input("Ask your brain something...")
        if test_query:
            results = rag_engine.search_knowledge_base(test_query, st.session_state.username)
            if not results:
                st.warning("No matches found.")
            for i, doc in enumerate(results):
                src_name = os.path.basename(doc.metadata.get('source', '?'))
                score = doc.metadata.get('relevance_score', 'N/A')
                with st.expander(f"Match {i+1}: {src_name} (Score: {score})"):
                    st.write(doc.page_content)
    # --- SUB-TAB 2: MANAGE (Unlocked for Everyone) ---
    with kb_tab2:
        st.subheader("ποΈ Database Inventory")
        docs = rag_engine.list_documents(st.session_state.username)
        if not docs:
            st.info("Your Knowledge Base is empty.")
        else:
            st.markdown(f"**Total Documents:** {len(docs)}")
            # One row per document: name | chunking strategy | chunk count | delete.
            for doc in docs:
                c1, c2, c3, c4 = st.columns([3, 2, 1, 1])
                with c1:
                    st.text(f"π {doc['filename']}")
                with c2:
                    st.caption(f"βοΈ {doc.get('strategy', 'Unknown')}")
                with c3:
                    st.caption(f"{doc['chunks']}")
                with c4:
                    # Source path doubles as a unique widget key for the button.
                    if st.button("ποΈ", key=doc['source'], help="Delete Document"):
                        with st.spinner("Deleting..."):
                            success, msg = rag_engine.delete_document(st.session_state.username, doc['source'])
                            if success:
                                # Mirror the deletion to the cloud backup.
                                tracker.upload_user_db(st.session_state.username)
                                st.success(msg)
                                st.rerun()
                            else:
                                st.error(msg)
        st.divider()
        with st.expander("π¨ Danger Zone"):
            # Allow ANY user to reset their OWN database
            if st.button("β’οΈ RESET MY DATABASE", type="primary"):
                success, msg = rag_engine.reset_knowledge_base(st.session_state.username)
                if success:
                    st.success(msg)
                    st.rerun()