import os, io, base64, json
from datetime import datetime
import streamlit as st
from typing import List, Dict, Any
from db import init_db, get_user_by_username, create_conversation, list_conversations, rename_conversation, delete_conversation, add_message, get_messages, list_users, set_user_role, set_user_active, update_user_password
from auth import hash_password, verify_password, ensure_admin
from providers import OpenAIProvider, OllamaProvider, ProviderError
st.set_page_config(page_title="ChatAI (Streamlit)", page_icon="💬", layout="wide")
# --- CSS for rounded chat bubbles & '+' floating button ---
custom_css = """
"""
st.markdown(custom_css, unsafe_allow_html=True)
# --- Initialize DB and default admin ---
init_db()
created_admin, admin_pwd = ensure_admin()
# --- Session State ---
if "user" not in st.session_state:
st.session_state.user = None
if "conversation_id" not in st.session_state:
st.session_state.conversation_id = None
if "show_uploader" not in st.session_state:
st.session_state.show_uploader = False
if "messages_cache" not in st.session_state:
st.session_state.messages_cache = []
def do_login(username, password):
user = get_user_by_username(username)
if not user or not user["is_active"]:
st.error("Sai tài khoản hoặc tài khoản đang bị khóa.")
return False
if verify_password(password, user["password_hash"]):
st.session_state.user = {"id": user["id"], "username": user["username"], "role": user["role"]}
return True
st.error("Mật khẩu không đúng.")
return False
def logout():
st.session_state.user = None
st.session_state.conversation_id = None
st.session_state.messages_cache = []
# --- Sidebar: Auth & Settings ---
with st.sidebar:
st.header("⚙️ Cấu hình")
if st.session_state.user:
st.success(f"Đã đăng nhập: **{st.session_state.user['username']}** ({st.session_state.user['role']})")
if st.button("Đăng xuất"):
logout()
st.rerun()
else:
st.subheader("Đăng nhập")
with st.form("login_form", clear_on_submit=False):
u = st.text_input("Tên đăng nhập")
p = st.text_input("Mật khẩu", type="password")
submitted = st.form_submit_button("Đăng nhập")
if submitted:
if do_login(u, p):
st.rerun()
st.divider()
st.subheader("Nhà cung cấp AI")
provider = st.selectbox("Provider", ["OpenAI", "Ollama"])
if provider == "OpenAI":
openai_key = st.text_input("OpenAI API Key", type="password", value=os.environ.get("OPENAI_API_KEY", ""))
model = st.text_input("Model", value="gpt-4o-mini")
else:
ollama_url = st.text_input("Ollama Endpoint", value="http://localhost:11434")
model = st.text_input("Model", value="llama3.1:8b")
temperature = st.slider("Temperature", 0.0, 1.0, 0.3, 0.05)
st.divider()
st.subheader("Tùy chọn")
sys_prompt = st.text_area("System Prompt (tùy chọn)", value="You are a helpful assistant. Answer in Vietnamese if the user speaks Vietnamese.")
if created_admin:
st.info(f"Admin mặc định đã được tạo. Tài khoản: admin / Mật khẩu: {admin_pwd} — hãy đổi ngay!")
# --- Page Router ---
def page_chat():
st.title("💬 ChatAI")
if not st.session_state.user:
st.info("Hãy đăng nhập để bắt đầu trò chuyện.")
return
# Conversations list
left, right = st.columns([1, 3])
with left:
st.subheader("🗂 Cuộc trò chuyện")
if st.button("➕ Tạo cuộc trò chuyện mới"):
st.session_state.conversation_id = create_conversation(st.session_state.user["id"], title="New Chat")
st.session_state.messages_cache = []
st.rerun()
convs = list_conversations(st.session_state.user["id"])
for c in convs:
selected = st.button(f"🗨 {c['title']}", key=f"conv_{c['id']}")
if selected:
st.session_state.conversation_id = c["id"]
st.session_state.messages_cache = get_messages(c["id"])
st.rerun()
with right:
if not st.session_state.conversation_id:
st.info("Chưa có cuộc trò chuyện. Hãy tạo mới bên trái.")
return
# Rename / Delete
cc1, cc2 = st.columns([3,1])
with cc1:
new_title = st.text_input("Tên cuộc trò chuyện", value="")
if st.button("Đổi tên"):
if new_title.strip():
rename_conversation(st.session_state.conversation_id, new_title.strip())
st.success("Đã đổi tên.")
else:
st.warning("Tên không hợp lệ.")
with cc2:
if st.button("🗑 Xóa cuộc trò chuyện", type="secondary"):
delete_conversation(st.session_state.conversation_id)
st.session_state.conversation_id = None
st.session_state.messages_cache = []
st.rerun()
# Chat history UI
msgs = st.session_state.messages_cache or get_messages(st.session_state.conversation_id)
for m in msgs:
role = m["role"]
content = m["content"]
with st.chat_message("assistant" if role=="assistant" else "user"):
st.markdown(content)
try:
atts = json.loads(m.get("attachments") or "[]")
for a in atts:
st.caption(f"📎 {a.get('name','file')} ({a.get('type','file')})")
except Exception:
pass
# Chat input
user_msg = st.chat_input("Nhập tin nhắn...")
# Floating '+' button for uploads
st.markdown('', unsafe_allow_html=True)
# Fallback toggle
if st.button("Hiện/Tắt upload (fallback)"):
st.session_state.show_uploader = not st.session_state.show_uploader
if st.session_state.show_uploader:
with st.container():
st.markdown('
', unsafe_allow_html=True)
st.write("**Tải lên để đính kèm**")
file_uploader = st.file_uploader("Tệp (txt, pdf)", type=["txt","pdf"], accept_multiple_files=True, key="file_up")
img_uploader = st.file_uploader("Ảnh", type=["png","jpg","jpeg","webp"], accept_multiple_files=True, key="img_up")
if st.button("Đóng"):
st.session_state.show_uploader = False
st.markdown('
', unsafe_allow_html=True)
# Process send
if user_msg or (st.session_state.get("file_up") or st.session_state.get("img_up")):
attachments = []
def extract_text_from_file(f):
name = f.name
if name.lower().endswith(".txt"):
return f.read().decode("utf-8", errors="ignore")
if name.lower().endswith(".pdf"):
try:
import PyPDF2
reader = PyPDF2.PdfReader(io.BytesIO(f.read()))
pages = []
for p in reader.pages:
pages.append(p.extract_text() or "")
return "\\n".join(pages)
except Exception as e:
return f"[Không thể trích xuất PDF: {e}]"
return ""
uploaded_files = st.session_state.get("file_up") or []
uploaded_imgs = st.session_state.get("img_up") or []
context_snippets = []
for f in uploaded_files:
text = extract_text_from_file(f)
attachments.append({"name": f.name, "type": "file", "size": f.size})
if text:
context_snippets.append(f"### {f.name}\\n{text[:6000]}")
image_refs = []
for img in uploaded_imgs:
b64 = base64.b64encode(img.read()).decode("utf-8")
mime = "image/png" if img.name.lower().endswith("png") else "image/jpeg"
image_refs.append({"name": img.name, "type": "image", "b64": b64, "mime": mime})
attachments.append({"name": img.name, "type": "image", "size": img.size})
db_msgs = get_messages(st.session_state.conversation_id)
chat_history = [{"role": m["role"], "content": m["content"]} for m in db_msgs]
system_preamble = sys_prompt.strip() if sys_prompt else ""
if context_snippets:
system_preamble += "\\n\\n# File context (tóm tắt)\\n" + "\\n\\n".join(context_snippets)
messages_for_provider: List[Dict[str, Any]] = []
if system_preamble:
messages_for_provider.append({"role": "system", "content": system_preamble})
messages_for_provider.extend(chat_history[-12:])
if user_msg:
messages_for_provider.append({"role": "user", "content": user_msg})
try:
if provider == "OpenAI":
p = OpenAIProvider(api_key=openai_key if openai_key else None)
resp = p.generate(messages=messages_for_provider, model=model, temperature=temperature)
else:
p = OllamaProvider(base_url=ollama_url)
resp = p.generate(messages=messages_for_provider, model=model, temperature=temperature)
except ProviderError as e:
resp = f"Lỗi nhà cung cấp: {e}"
except Exception as e:
resp = f"Lỗi không xác định: {e}"
if user_msg:
add_message(st.session_state.conversation_id, "user", user_msg, attachments=json.dumps(attachments))
st.session_state.messages_cache.append({"role":"user","content":user_msg,"attachments":json.dumps(attachments)})
add_message(st.session_state.conversation_id, "assistant", resp)
st.session_state.messages_cache.append({"role":"assistant","content":resp,"attachments":"[]"})
st.session_state.show_uploader = False
st.rerun()
# Export
st.divider()
if st.button("⬇️ Xuất cuộc trò chuyện (Markdown)"):
msgs = get_messages(st.session_state.conversation_id)
md = ["# Lịch sử trò chuyện"]
for m in msgs:
who = "👤 User" if m["role"]=="user" else "🤖 Assistant"
md.append(f"**{who}**\\n\\n{m['content']}\\n")
b = "\\n\\n---\\n\\n".join(md).encode("utf-8")
st.download_button("Tải về .md", data=b, file_name="chat_history.md", mime="text/markdown")
def page_admin():
if not st.session_state.user or st.session_state.user["role"] != "admin":
st.warning("Chỉ Admin mới truy cập được trang này.")
return
st.title("🛠 Admin Panel")
st.subheader("Quản lý người dùng")
users = list_users()
for u in users:
cols = st.columns([2,1,1,1,2])
cols[0].write(f"**{u['username']}**")
cols[1].write(u["role"])
cols[2].write("✅" if u["is_active"] else "🚫")
if cols[3].button("Đổi vai trò", key=f"role_{u['id']}"):
new_role = "admin" if u["role"]=="user" else "user"
set_user_role(u["id"], new_role)
st.rerun()
with cols[4]:
c1, c2, c3 = st.columns(3)
if c1.button("Khóa/Mở", key=f"toggle_{u['id']}"):
set_user_active(u["id"], not u["is_active"])
st.rerun()
if c2.button("Đặt lại MK", key=f"reset_{u['id']}"):
newpw = st.text_input(f"Mật khẩu mới cho {u['username']}", key=f"pw_{u['id']}")
if newpw:
update_user_password(u["id"], hash_password(newpw))
st.success("Đã cập nhật mật khẩu.")
st.rerun()
if c3.button("Xóa", key=f"del_{u['id']}"):
from db import delete_user as delu
if u["username"]=="admin":
st.error("Không được xóa tài khoản admin gốc.")
else:
delu(u["id"])
st.rerun()
st.divider()
st.subheader("Thêm người dùng mới")
with st.form("new_user_form"):
nu = st.text_input("Tên đăng nhập")
np = st.text_input("Mật khẩu", type="password")
role = st.selectbox("Vai trò", ["user", "admin"])
submitted = st.form_submit_button("Tạo")
if submitted:
if not nu or not np:
st.error("Thiếu thông tin.")
elif get_user_by_username(nu):
st.error("Tên đăng nhập đã tồn tại.")
else:
from db import create_user
create_user(nu, hash_password(np), role=role, is_active=True)
st.success("Đã tạo người dùng.")
st.rerun()
# --- Main ---
tab = st.tabs(["💬 Chat", "🛠 Admin"])
with tab[0]:
page_chat()
with tab[1]:
page_admin()
# Small JS to listen to postMessage and toggle uploader (best-effort, uses the fallback button)
st.markdown(\"""
\""", unsafe_allow_html=True)