Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| from streamlit_chat import message | |
| from model.utils import create_model, setup_logger, generate_content | |
| import google.generativeai as genai | |
| import base64 | |
| from settings.base import setup_logger, GOOGLE_API_KEY | |
| from user_auth.user_manager import UserManager | |
| from user_auth.auth_manager import AuthManager | |
# Module-level logger used by the chat helpers below.
logger = setup_logger()

# Seed every session-state slot the app reads, but only when it is missing,
# so values set during earlier reruns of this script are kept.
# The dict is rebuilt on every run, so the list default is never shared.
_SESSION_DEFAULTS = {
    'history': [],          # list of chat turns
    'chat_session': None,
    'authenticated': False,
    'email': "",
    'profile': None,
    'model': None,
}
for _state_key, _initial_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
def authenticate_user(email, password):
    """
    Check a user's credentials through AuthManager.

    A fresh UserManager and AuthManager are built on every call, and the
    result of ``AuthManager.authenticate_user(email, password)`` is returned
    unchanged.
    """
    return AuthManager(UserManager()).authenticate_user(email, password)
def conversation_chat(file_path, text_prompt):
    """
    Send one user turn to the model and record the exchange in session history.

    Args:
        file_path: The uploaded-file object from ``st.file_uploader`` (it
            exposes ``name`` and ``read()``), or a falsy value when nothing
            was uploaded. Despite the name, this is not a path string.
        text_prompt: Text typed by the user; a falsy value is treated as "".

    Side effects:
        * Writes the upload to ``temp/<file name>`` and uploads it with
          ``genai.upload_file``. The temp copy is kept because the chat view
          reads it back to preview images.
        * On success appends a ``{"role": "user", ...}`` entry followed by a
          ``{"role": "model", ...}`` entry to ``st.session_state.history``.
        * On failure logs the error, removes anything this call added to the
          history, and shows an error banner. The history therefore only ever
          holds well-formed role/parts dicts.
    """
    temp_dir = "temp"
    os.makedirs(temp_dir, exist_ok=True)

    user_input = text_prompt if text_prompt else ""
    if not file_path and not user_input:
        # Nothing to send; avoid posting an empty turn to the model.
        return

    history = st.session_state.history
    turn_start = len(history)  # rollback point if this turn fails
    try:
        if file_path:
            # Keep only the final path component so a crafted upload name
            # cannot escape the temp directory.
            safe_name = os.path.basename(file_path.name)
            temp_file_path = os.path.join(temp_dir, safe_name)
            with open(temp_file_path, "wb") as f:
                f.write(file_path.read())  # Write the uploaded file to temp location
            logger.info(f"Successfully saved uploaded file: {safe_name}")
            file = genai.upload_file(path=temp_file_path, display_name=safe_name)
            # The chat view expects [file, text] in exactly this order.
            parts = [file, user_input]
        else:
            parts = [user_input]

        history.append({"role": "user", "parts": parts})
        response_text = generate_content(st.session_state.model, history)
        history.append({"role": "model", "parts": response_text})
        logger.info("Conversation successfully processed")
    except Exception as e:
        logger.error(f"Error processing file: {e}")
        # Drop the partial turn instead of storing a bare "Error" string.
        # The chat view indexes every entry by 'role'/'parts', and the same
        # list is sent to the model on the next turn.
        del history[turn_start:]
        st.error("Sorry, something went wrong while processing your message. Please try again.")
def _render_user_upload(index, uploaded_file):
    """Render the file part of a user turn: inline preview for images, file name otherwise."""
    file_name = uploaded_file.display_name
    if uploaded_file.mime_type.startswith("image/"):
        try:
            with open(os.path.join("temp", file_name), "rb") as image_file:
                encoded_img = base64.b64encode(image_file.read()).decode('utf-8')
        except OSError:
            # The local temp copy is gone; fall through to the file-name bubble.
            pass
        else:
            # Use the file's own MIME type rather than assuming PNG.
            img_html = (
                f'<img src="data:{uploaded_file.mime_type};base64,{encoded_img}" '
                'width="200" style="margin-top: 5px;"/>'
            )
            # Built flush-left so Markdown never treats it as an indented code block.
            bubble_html = (
                '<div style="display: flex; justify-content: flex-end; margin-bottom: 10px;">'
                '<div style="max-width: 300px; background-color: #E8E8E8; border-radius: 10px; '
                'padding: 10px; position: relative;">'
                '<div style="text-align: right;">'
                f'{img_html}'
                '</div></div></div>'
            )
            st.markdown(bubble_html, unsafe_allow_html=True)
            return
    # This key must differ from the text bubble's key for the same turn.
    message(file_name, is_user=True, key=f"{index}_user_file", avatar_style="adventurer")


def display_chat():
    """
    Render the chat page: title, Clear Chat and Logout buttons, upload form and transcript.

    Reads ``st.session_state.profile`` for the greeting and
    ``st.session_state.history`` for the transcript. Submitting the form
    passes the upload and text to ``conversation_chat``. Logout resets all
    session fields and reruns the script. Clear Chat empties the history and
    resets ``chat_session``.
    """
    profile = st.session_state.profile
    st.title("Wellness Bot 🦾🤖")

    # Containers are created first so the transcript is laid out above the
    # form even though the form is processed before the transcript is drawn.
    chat_container = st.container()
    upload_container = st.container()
    clear_chat_button = st.button('Clear Chat')

    if st.button('Logout'):
        st.session_state.authenticated = False
        st.session_state.email = ""
        st.session_state.history = []
        st.session_state.chat_session = None
        st.session_state.profile = None
        st.session_state.model = None
        st.rerun()

    with upload_container:
        with st.form(key='chat_form', clear_on_submit=True):
            file_path = st.file_uploader("Upload an image or audio of a meal:", type=["jpg", "jpeg", "png", "mpeg", "mp3", "wav", "ogg", "mp4"])
            text_prompt = st.text_input("Type Here...")
            submit_button = st.form_submit_button(label='Send ⬆️')
        if submit_button:
            conversation_chat(file_path, text_prompt)

    if clear_chat_button:
        st.session_state.history = []
        st.session_state.chat_session = None

    with chat_container:
        message(f"Hey {profile['name']}! I'm here to assist you with your meals. Let's make healthy eating a breeze. Feel free to upload an image/video/audio of your meal or ask any questions about nutrition, dietary needs, and meal planning. Together, we'll achieve your goals and ensure you stay healthy and happy.", avatar_style="bottts")
        for i, entry in enumerate(st.session_state.history):
            if not isinstance(entry, dict):
                # The history may hold a bare string (e.g. an error marker);
                # show it as a bot bubble instead of failing on entry['role'].
                message(str(entry), key=f"{i}_error", avatar_style="bottts")
                continue

            if entry['role'] == 'user':
                parts = entry['parts']
                if len(parts) > 1:
                    # Turn with an upload: parts is [file, text].
                    _render_user_upload(i, parts[0])
                    if parts[1] != "":
                        message(parts[1], is_user=True, key=f"{i}_user", avatar_style="adventurer")
                else:
                    message(parts[0], is_user=True, key=f"{i}_user", avatar_style="adventurer")
            elif entry['role'] == 'model':
                message(entry['parts'], key=str(i), avatar_style="bottts")
def _render_login_tab(auth_manager):
    """Draw the login form; on valid credentials mark the session authenticated and rerun."""
    st.header("Login")
    email = st.text_input("Email", key="login_email")
    password = st.text_input("Password", type="password", key="login_password")
    if st.button("Login"):
        if auth_manager.authenticate_user(email, password):
            st.success("Login successful!")
            st.session_state.authenticated = True
            st.session_state.email = email
            st.rerun()
        else:
            st.error("Invalid email or password")


def _render_signup_tab(user_manager):
    """Draw the sign-up form and create the account through ``user_manager.create_user``."""
    st.header("Create Account")
    email = st.text_input("Email", key="signup_email")
    password = st.text_input("Password", type="password", key="signup_password")
    confirm_password = st.text_input("Confirm Password", type="password", key="signup_confirm_password")
    name = st.text_input("Name")
    age = st.number_input("Age", min_value=1, max_value=120)
    gender = st.selectbox("Gender", ["Male", "Female", "Other"])
    height = st.number_input("Height (cm)", min_value=30, max_value=300)
    weight = st.number_input("Weight (kg)", min_value=1, max_value=500)
    location = st.text_input("Location")
    allergies = st.text_input("Allergies (comma-separated)")
    spec_diet_pref = st.text_input("Special Dietary Preferences")
    primary_goal = st.text_area("Primary Goal")
    health_condition = st.text_area("Health Condition")
    activity_level = st.selectbox("Activity Level", ["Sedentary", "Lightly Active", "Moderately Active", "Very Active"])
    daily_calorie_intake = st.number_input("Daily Calorie Intake", min_value=1, max_value=10000)

    if not st.button("Sign Up"):
        return

    if not email or not password:
        st.error("Email and password are required")
        return
    if password != confirm_password:
        st.error("Passwords do not match")
        return

    profile = {
        "name": name,
        "age": age,
        "gender": gender,
        "height": height,
        "weight": weight,
        "location": location,
        # Trim each item and drop blanks so empty input gives [] rather
        # than [''], and "a, b" gives ['a', 'b'] rather than ['a', ' b'].
        "allergies": [item.strip() for item in allergies.split(',') if item.strip()],
        "spec_diet_pref": spec_diet_pref,
        "primary_goal": primary_goal,
        "health_condition": health_condition,
        "activity_level": activity_level,
        "daily_calorie_intake": daily_calorie_intake,
    }
    # Named status_message so the imported streamlit_chat ``message`` is not shadowed.
    success, status_message = user_manager.create_user(email, password, profile)
    if success:
        st.success(status_message)
    else:
        st.error(status_message)


def main():
    """
    Main function to run the Streamlit app.

    Authenticated sessions get the chat page. The user's profile and the
    model are loaded into session state the first time they are missing.
    Everyone else sees the Login and Sign Up tabs.
    """
    user_manager = UserManager()
    auth_manager = AuthManager(user_manager)
    st.set_page_config(page_title="Wellness Bot 🦾🤖", page_icon=":fork_and_knife:")

    if st.session_state.authenticated:
        if st.session_state.profile is None or st.session_state.model is None:
            st.session_state.profile = user_manager.get_user(st.session_state.email).profile
            st.session_state.model = create_model(GOOGLE_API_KEY, st.session_state.profile)
        display_chat()
        return

    st.title("Wellness Bot 🦾🤖 \n\nLogin/SignUp")
    tab1, tab2 = st.tabs(["Login", "Sign Up"])
    with tab1:
        _render_login_tab(auth_manager)
    with tab2:
        _render_signup_tab(user_manager)


if __name__ == "__main__":
    main()