# Shinui / src/streamlit_app.py
# Pontonkid's picture
# Update src/streamlit_app.py
# 37b828a verified
import html
import os

import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# -----------------------------------------------------------------------------
# 0. AUTO-FIX FOR UPLOAD ERROR (RUNS INSTANTLY)
# -----------------------------------------------------------------------------
# This creates the config.toml automatically so uploads work.
# Make sure the local Streamlit config directory exists and (re)write its
# config.toml on every run. exist_ok=True replaces the separate exists() check,
# which could race with another process creating the directory in between.
#
# SECURITY NOTE(review): the settings below switch off XSRF protection and CORS
# checks for the server. That is a deliberate workaround for upload errors, but
# it weakens request-origin protection — confirm it is acceptable for the
# deployment before keeping it.
# NOTE(review): Streamlit presumably reads config.toml when the server starts,
# so a file written here may only take effect on the next launch — confirm.
config_dir = ".streamlit"
os.makedirs(config_dir, exist_ok=True)
with open(os.path.join(config_dir, "config.toml"), "w", encoding="utf-8") as config_file:
    config_file.write(
        "[server]\nenableXsrfProtection=false\nenableCORS=false\nmaxUploadSize=200\n"
    )
# -----------------------------------------------------------------------------
# 1. SETUP & MODEL LOADING (AarambhAI Gemma)
# -----------------------------------------------------------------------------
# Page-level settings: browser tab title and icon, wide layout.
st.set_page_config(
    page_title="SHINUI | Gemma AI",
    page_icon="✨",
    layout="wide",
)
@st.cache_resource
def load_model():
    """Load the multimodal checkpoint together with its processor.

    The function is wrapped in ``st.cache_resource``; the processor is loaded
    first, then the model.

    Returns:
        A ``(model, processor)`` tuple.
    """
    checkpoint = "AarambhAI/gemma-like-multimodal-speech-vision-text"

    # trust_remote_code=True is required because the checkpoint uses a custom
    # architecture.
    processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)

    model_kwargs = {
        "torch_dtype": torch.float32,  # float32 is safer for CPU
        "device_map": "auto",
        "trust_remote_code": True,
    }
    model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs)

    return model, processor
# Load Model on App Start
# Attempt the model load once per script run. MODEL_LOADED records whether
# `model` and `processor` are available; on failure the error is shown in the
# page and the two names stay undefined.
MODEL_LOADED = False
try:
    with st.spinner("Initializing Gemma Multimodal Model..."):
        model, processor = load_model()
except Exception as load_error:
    st.error(f"⚠️ Model Load Error: {load_error}")
else:
    MODEL_LOADED = True
# -----------------------------------------------------------------------------
# 2. STATE MANAGEMENT
# -----------------------------------------------------------------------------
# Seed session defaults; any key that is already present is left untouched.
_SESSION_DEFAULTS = {
    'page': 'landing',
    'logged_in': False,
    'user_email': "",
    'history': [],
    'result': None,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# -----------------------------------------------------------------------------
# 3. THE BRAIN (Gemma Logic)
# -----------------------------------------------------------------------------
def get_gemma_insight(input_type, content):
    """Run the loaded model on an image or on text and return its decoded reply.

    Args:
        input_type: ``"Image"`` or ``"Text"``; any other value yields an error
            string instead of silently returning ``None``.
        content: For ``"Image"``, the object passed to the processor as
            ``images=``; for ``"Text"``, the text embedded in the prompt.

    Returns:
        The first decoded sequence produced by ``model.generate`` (at most 200
        new tokens), or a human-readable error string when the model is not
        loaded, the input type is unsupported, or processing raises. The
        result is always a ``str`` so callers can slice it safely.
    """
    if not MODEL_LOADED:
        return "Error: Model not loaded."

    # Build the processor arguments for the requested modality.
    if input_type == "Image":
        processor_kwargs = {
            "text": "Analyze this medical image and list observations.",
            "images": content,
        }
    elif input_type == "Text":
        processor_kwargs = {"text": f"Medical analysis for: {content}"}
    else:
        return f"⚠️ Processing Error: Unsupported input type: {input_type!r}"

    try:
        inputs = processor(return_tensors="pt", **processor_kwargs)

        # The model is loaded with device_map="auto", so it may not live on the
        # CPU; move the inputs to the model's device to avoid a device mismatch
        # in generate(). Guarded because the processor is custom code and its
        # return type is not visible here.
        target_device = getattr(model, "device", None)
        if target_device is not None and hasattr(inputs, "to"):
            inputs = inputs.to(target_device)

        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=200)
        return processor.batch_decode(output, skip_special_tokens=True)[0]
    except Exception as e:
        # UI boundary: report the failure as text rather than crashing the page.
        return f"⚠️ Processing Error: {str(e)}"
# -----------------------------------------------------------------------------
# 4. UI STYLING (Clean Dark Theme)
# -----------------------------------------------------------------------------
# Global stylesheet: dark gradient background, web font, the reusable
# `.shinui-card` panel, button styling, and hiding of Streamlit's menu,
# footer and header.
_GLOBAL_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@300;400;600;800&display=swap');
.stApp {
background-color: #020617;
background-image: radial-gradient(circle at 50% 0%, #1e293b 0%, #020617 70%);
font-family: 'Plus Jakarta Sans', sans-serif; color: #f8fafc;
}
.shinui-card {
background: rgba(30, 41, 59, 0.4); border: 1px solid rgba(148, 163, 184, 0.1);
border-radius: 16px; padding: 25px; backdrop-filter: blur(12px); margin-bottom: 20px;
}
div.stButton > button {
background: #38bdf8; color: #0f172a; border: none; font-weight: 700;
padding: 12px 20px; border-radius: 8px; width: 100%; transition: all 0.3s;
}
div.stButton > button:hover { background: #ffffff; box-shadow: 0 0 20px rgba(56, 189, 248, 0.5); }
#MainMenu, footer, header {visibility: hidden;}
</style>
"""
st.markdown(_GLOBAL_CSS, unsafe_allow_html=True)
# -----------------------------------------------------------------------------
# 5. NAVIGATION
# -----------------------------------------------------------------------------
def nav_to(page):
    """Record ``page`` as the active page in session state and rerun the script.

    Args:
        page: Value stored under the ``page`` session key.
    """
    st.session_state["page"] = page
    st.rerun()
def sign_out():
    """Reset login flag, history, last result and email, then go to landing."""
    cleared_state = {
        "logged_in": False,
        "history": [],
        "result": None,
        "user_email": "",
    }
    for state_key, state_value in cleared_state.items():
        st.session_state[state_key] = state_value
    nav_to('landing')
# -----------------------------------------------------------------------------
# 6. PAGES
# -----------------------------------------------------------------------------
# --- LANDING ---
def show_landing():
    """Render the landing page: brand mark, hero text with two navigation
    buttons, and a feature card."""
    hero_html = """
<h1 style='font-size: 4rem; line-height: 1.1; margin-bottom: 20px;'>
Medical Intelligence.<br><span style='color:#38bdf8;'>Runs Locally.</span>
</h1>
<p style='font-size: 1.2rem; color: #94a3b8; margin-bottom: 40px;'>
SHINUI runs the specialized Gemma Multimodal model for secure analysis.
</p>
"""
    card_html = """
<div class='shinui-card'>
<h3>🧬 Gemma Multimodal</h3>
<p style='color:#94a3b8;'>Vision, Text & Speech capable.</p>
</div>
"""

    # Brand mark goes in the narrow first column; the second column is unused.
    brand_col, _ = st.columns([1, 8])
    with brand_col:
        st.markdown("### ✨ SHINUI")
    st.markdown("<br><br>", unsafe_allow_html=True)

    hero_col, card_col = st.columns([1.5, 1])
    with hero_col:
        st.markdown(hero_html, unsafe_allow_html=True)
        sign_in_col, about_col = st.columns([1, 2])
        with sign_in_col:
            if st.button("Sign In"):
                nav_to('login')
        with about_col:
            if st.button("About SHINUI"):
                nav_to('about')
    with card_col:
        st.markdown(card_html, unsafe_allow_html=True)
# --- ABOUT ---
def show_about():
    """Render the public "About" page with a button back to the landing page."""
    about_html = """
<div class='shinui-card'>
<h2 style='color:#38bdf8'>About SHINUI</h2>
<p style='font-size:1.1rem; line-height:1.6'>
SHINUI utilizes the <b>AarambhAI Gemma-like Multimodal</b> model.
This model is unique because it understands images, text, and speech natively in a single architecture.
</p>
<hr style='border-color:#333'>
<h3>Capabilities</h3>
<ul>
<li><b>Visual Diagnostics:</b> Reads medical images.</li>
<li><b>Clinical Text:</b> Analyzes symptoms and notes.</li>
</ul>
</div>
"""
    go_back = st.button("← Back Home")
    if go_back:
        nav_to('landing')
    st.markdown("<br>", unsafe_allow_html=True)
    st.markdown(about_html, unsafe_allow_html=True)
# --- LOGIN ---
def show_login():
    """Render the sign-in form in the middle of three equal columns.

    On "Authenticate", a non-blank email marks the session as logged in, stores
    the trimmed email and navigates to the dashboard; a blank or
    whitespace-only email shows a warning instead of silently doing nothing.
    "Back" returns to the landing page.
    """
    _, form_col, _ = st.columns([1, 1, 1])
    with form_col:
        st.markdown("<br><br>", unsafe_allow_html=True)
        st.markdown("<div class='shinui-card' style='text-align:center;'><h2>Member Access</h2></div>", unsafe_allow_html=True)
        email = st.text_input("Email")
        # SECURITY NOTE(review): the password field is rendered but its value is
        # never verified — any non-blank email is accepted. Wire this to a real
        # credential check before relying on it for access control.
        st.text_input("Password", type="password")
        if st.button("Authenticate"):
            cleaned_email = (email or "").strip()
            if cleaned_email:
                st.session_state.logged_in = True
                st.session_state.user_email = cleaned_email
                nav_to('dashboard')
            else:
                st.warning("Please enter your email to continue.")
        if st.button("Back"):
            nav_to('landing')
# --- DASHBOARD ---
def show_dashboard():
    """Render the signed-in dashboard.

    Sidebar: the user's email, a link to the internal about page, the history
    list (newest first) and a sign-out button. Main area: an image tab and a
    text tab that each send their input to ``get_gemma_insight``, store the
    reply in ``st.session_state.result`` and append a short summary to
    ``st.session_state.history``. The latest result is shown below the tabs.

    Model output and history entries are HTML-escaped before being placed into
    markup rendered with ``unsafe_allow_html=True``, because that text can echo
    user-supplied input.
    """
    with st.sidebar:
        st.markdown(f"### 👤 {st.session_state.user_email}")
        if st.button("About System"):
            nav_to('about_internal')
        st.markdown("---")
        st.write("HISTORY")
        if st.session_state.history:
            for entry in reversed(st.session_state.history):
                # Truncate first, then escape, so an entity is never cut in half.
                safe_entry = html.escape(entry[:50])
                st.markdown(f"<div style='font-size:0.8rem; padding:5px; border-left:2px solid #38bdf8; margin-bottom:5px;'>{safe_entry}...</div>", unsafe_allow_html=True)
        else:
            st.caption("No scans yet.")
        st.markdown("---")
        if st.button("Sign Out"):
            sign_out()

    st.title("Gemma Interface")
    image_tab, text_tab = st.tabs(["📷 Image Scan", "📝 Text Analysis"])

    # NOTE(review): each tab emits an opening and a closing card <div> as
    # separate markdown calls; whether Streamlit actually wraps the widgets
    # between them is not evident from this file — confirm.

    # TAB 1: IMAGE
    with image_tab:
        st.markdown("<div class='shinui-card'>", unsafe_allow_html=True)
        img_file = st.file_uploader("Upload Medical Image", type=['png','jpg','jpeg'])
        if img_file and st.button("Analyze Visual"):
            if not MODEL_LOADED:
                st.error("Model failed to load (Check Space Logs).")
            else:
                try:
                    image = Image.open(img_file)
                except OSError as exc:
                    # Pillow reports unreadable or non-image uploads as OSError
                    # subclasses; show a message instead of a traceback.
                    st.error(f"Could not read the uploaded image: {exc}")
                else:
                    st.image(image, width=300)
                    with st.spinner("Gemma Processing..."):
                        res = get_gemma_insight("Image", image)
                    st.session_state.result = res
                    st.session_state.history.append(f"Image: {res[:30]}...")
        st.markdown("</div>", unsafe_allow_html=True)

    # TAB 2: TEXT
    with text_tab:
        st.markdown("<div class='shinui-card'>", unsafe_allow_html=True)
        txt = st.text_area("Clinical Notes / Symptoms")
        if txt and st.button("Analyze Notes"):
            if not MODEL_LOADED:
                st.error("Model failed to load.")
            else:
                with st.spinner("Gemma Processing..."):
                    res = get_gemma_insight("Text", txt)
                st.session_state.result = res
                st.session_state.history.append(f"Text: {res[:30]}...")
        st.markdown("</div>", unsafe_allow_html=True)

    if st.session_state.result:
        safe_result = html.escape(str(st.session_state.result))
        st.markdown(f"""
<div class='shinui-card' style='border-left: 5px solid #38bdf8;'>
<h3 style='margin-top:0; color:#38bdf8;'>Analysis Result</h3>
<div style='white-space: pre-wrap; color: #e2e8f0; line-height: 1.6;'>{safe_result}</div>
</div>
""", unsafe_allow_html=True)
# --- INTERNAL ABOUT ---
def show_about_internal():
    """Render the signed-in "System Status" page with a sidebar back button."""
    status_html = """
<div class='shinui-card'>
<h2 style='color:#38bdf8'>System Status</h2>
<p><b>Model:</b> AarambhAI Gemma-like Multimodal</p>
<p><b>Backend:</b> Local Transformers</p>
</div>
"""
    with st.sidebar:
        back_clicked = st.button("← Back")
        if back_clicked:
            nav_to('dashboard')
    st.markdown(status_html, unsafe_allow_html=True)
# -----------------------------------------------------------------------------
# 7. ROUTER
# -----------------------------------------------------------------------------
# Dispatch on the current page key. Public pages render directly; protected
# pages render only for a logged-in session and otherwise redirect to login.
# A page key found in neither table renders nothing.
_PUBLIC_PAGES = {
    'landing': show_landing,
    'about': show_about,
    'login': show_login,
}
_PROTECTED_PAGES = {
    'dashboard': show_dashboard,
    'about_internal': show_about_internal,
}

_current_page = st.session_state.page
if _current_page in _PUBLIC_PAGES:
    _PUBLIC_PAGES[_current_page]()
elif _current_page in _PROTECTED_PAGES:
    if st.session_state.logged_in:
        _PROTECTED_PAGES[_current_page]()
    else:
        nav_to('login')