import streamlit as st
import os
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# -----------------------------------------------------------------------------
# 0. AUTO-FIX FOR UPLOAD ERROR (RUNS INSTANTLY)
# -----------------------------------------------------------------------------
# Write the Streamlit server config that the upload widget depends on.
# NOTE(review): these settings turn off XSRF protection and CORS checks --
# confirm that is acceptable for wherever this app is deployed.
config_dir = ".streamlit"
config_path = os.path.join(config_dir, "config.toml")
config_body = "[server]\nenableXsrfProtection=false\nenableCORS=false\nmaxUploadSize=200\n"

# exist_ok removes the check-then-create race of a separate os.path.exists().
os.makedirs(config_dir, exist_ok=True)

# Avoid rewriting an unchanged file on every run of this script.
try:
    with open(config_path, encoding="utf-8") as f:
        config_is_current = f.read() == config_body
except (OSError, UnicodeDecodeError):
    # Missing or unreadable file: fall through and (re)create it.
    config_is_current = False

if not config_is_current:
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(config_body)
# -----------------------------------------------------------------------------
# 1. SETUP & MODEL LOADING (AarambhAI Gemma)
# -----------------------------------------------------------------------------
# Browser tab title and icon, plus the wide page layout.
st.set_page_config(
    page_title="SHINUI | Gemma AI",
    page_icon="✨",
    layout="wide",
)
@st.cache_resource
def load_model():
    """Load the AarambhAI multimodal checkpoint together with its processor.

    The function is decorated with ``st.cache_resource``. Both objects are
    loaded with ``trust_remote_code=True`` because the checkpoint uses a
    custom architecture.

    Returns:
        A ``(model, processor)`` tuple.
    """
    checkpoint = "AarambhAI/gemma-like-multimodal-speech-vision-text"

    processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)

    model_options = {
        "torch_dtype": torch.float32,  # float32 is safer for CPU
        "device_map": "auto",
        "trust_remote_code": True,
    }
    model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_options)

    return model, processor
# Load the model at app start; MODEL_LOADED records whether that succeeded.
try:
    with st.spinner("Initializing Gemma Multimodal Model..."):
        model, processor = load_model()
except Exception as e:
    # Report the failure in the UI and keep the app running without a model.
    st.error(f"⚠️ Model Load Error: {e}")
    MODEL_LOADED = False
else:
    MODEL_LOADED = True
# -----------------------------------------------------------------------------
# 2. STATE MANAGEMENT
# -----------------------------------------------------------------------------
# Seed each session key with its initial value; keys already present are
# left untouched.
for state_key, initial_value in (
    ("page", "landing"),
    ("logged_in", False),
    ("user_email", ""),
    ("history", []),
    ("result", None),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value
# -----------------------------------------------------------------------------
# 3. THE BRAIN (Gemma Logic)
# -----------------------------------------------------------------------------
def get_gemma_insight(input_type, content):
    """Run the loaded model on an image or a text snippet and return its reply.

    Args:
        input_type: ``"Image"`` or ``"Text"``.
        content: For ``"Image"``, the object handed to the processor as
            ``images=``; for ``"Text"``, the text interpolated into the prompt.

    Returns:
        The first decoded sequence produced by the model, or an error string
        when the model is not loaded, ``input_type`` is unsupported, or
        processing raises. Errors are reported through the return value
        rather than raised.
    """
    if not MODEL_LOADED:
        return "Error: Model not loaded."

    try:
        # Only the processor arguments differ between the two modes; the
        # generate/decode steps below are shared.
        if input_type == "Image":
            processor_kwargs = {
                "text": "Analyze this medical image and list observations.",
                "images": content,
            }
        elif input_type == "Text":
            processor_kwargs = {"text": f"Medical analysis for: {content}"}
        else:
            # Previously this fell through and returned None.
            return f"⚠️ Processing Error: Unsupported input type: {input_type!r}"

        inputs = processor(**processor_kwargs, return_tensors="pt")

        # The model is loaded with device_map="auto", so it may not live on
        # the CPU; move the inputs next to it when both sides support that.
        target_device = getattr(model, "device", None)
        if target_device is not None and hasattr(inputs, "to"):
            inputs = inputs.to(target_device)

        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=200)
        return processor.batch_decode(output, skip_special_tokens=True)[0]
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the page.
        return f"⚠️ Processing Error: {e}"
# -----------------------------------------------------------------------------
# 4. UI STYLING (Clean Dark Theme)
# -----------------------------------------------------------------------------
# Emit raw markup through st.markdown with HTML rendering enabled.
# NOTE(review): the payload is currently just a newline, so no styling is
# applied -- the stylesheet content appears to be missing; confirm and restore.
st.markdown("""
""", unsafe_allow_html=True)
# -----------------------------------------------------------------------------
# 5. NAVIGATION
# -----------------------------------------------------------------------------
def nav_to(page):
    """Store ``page`` as the current page in session state, then call ``st.rerun()``.

    Args:
        page: Value written to the ``page`` key of the session state.
    """
    st.session_state["page"] = page
    st.rerun()
def sign_out():
    """Reset the login-related session values and navigate to the landing page."""
    cleared_state = {
        "logged_in": False,
        "history": [],
        "result": None,
        "user_email": "",
    }
    for state_key, cleared_value in cleared_state.items():
        st.session_state[state_key] = cleared_value
    nav_to("landing")
# -----------------------------------------------------------------------------
# 6. PAGES
# -----------------------------------------------------------------------------
# --- LANDING ---
def show_landing():
c1, c2 = st.columns([1, 8])
with c1: st.markdown("### ✨ SHINUI")
st.markdown("
", unsafe_allow_html=True)
c1, c2 = st.columns([1.5, 1])
with c1:
st.markdown("""
SHINUI runs the specialized Gemma Multimodal model for secure analysis.
""", unsafe_allow_html=True) b1, b2 = st.columns([1, 2]) with b1: if st.button("Sign In"): nav_to('login') with b2: if st.button("About SHINUI"): nav_to('about') with c2: st.markdown("""Vision, Text & Speech capable.
SHINUI utilizes the AarambhAI Gemma-like Multimodal model. This model is unique because it understands images, text, and speech natively in a single architecture.
Model: AarambhAI Gemma-like Multimodal
Backend: Local Transformers