# Uploaded by KB-Infinity-Tech -- commit 9becdce (verified): "Update src/streamlit_app.py"
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from PIL import Image
import numpy as np
from gtts import gTTS
import tempfile
import re
# ----------------------------
# MODELS
# ----------------------------
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"


@st.cache_resource
def load_llm():
    """Load the tokenizer and causal language model named by MODEL_NAME.

    The model is loaded as float32 weights with ``device_map="cpu"``.
    The function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    model_options = {
        "torch_dtype": torch.float32,
        "device_map": "cpu",
    }
    llm_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    llm_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_options)
    return tokenizer_and_model(llm_tokenizer, llm_model)


def tokenizer_and_model(llm_tokenizer, llm_model):
    """Bundle the two loaded objects in the order callers unpack them."""
    return llm_tokenizer, llm_model
@st.cache_resource
def load_asr():
    """Build the speech-recognition pipeline backed by ``openai/whisper-tiny``.

    The function is wrapped in ``st.cache_resource``.

    Returns:
        The object returned by ``pipeline`` for the
        ``automatic-speech-recognition`` task.
    """
    recognizer = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-tiny",
    )
    return recognizer
# Module-level handles used by generate() and the voice mode below.
_loaded_llm = load_llm()
tokenizer, model = _loaded_llm
asr = load_asr()
# ----------------------------
# MULTILINGUAL DETECTION
# ----------------------------
# Number words (one to five) used as language markers, keyed by language code.
LANG_WORDS = {
    "en": ["one", "two", "three", "four", "five"],
    "fr": ["un", "deux", "trois", "quatre", "cinq"],
    "sw": ["moja", "mbili", "tatu", "nne", "tano"],
    "kin": ["imwe", "ebyiri", "eshatu", "enye", "eshanu"],
}


def detect_mixed_language(text):
    """Guess which languages appear in *text* from its number words.

    Each language scores one point for every entry of its ``LANG_WORDS``
    list that occurs in the text as a whole word. Matching is done on word
    tokens rather than substrings, so "money" no longer counts as English
    "one" and "under" no longer counts as French "un".

    Args:
        text: The user's utterance.

    Returns:
        A ``(dominant, languages)`` tuple. ``dominant`` is the language code
        with the highest score; ties, including the all-zero case, resolve
        to the first code in ``LANG_WORDS`` order. ``languages`` lists every
        code with a non-zero score in ``LANG_WORDS`` order, or
        ``[dominant]`` when nothing matched.
    """
    tokens = set(re.findall(r"\w+", text.lower()))
    scores = {
        lang: sum(1 for word in words if word in tokens)
        for lang, words in LANG_WORDS.items()
    }
    dominant = max(scores, key=scores.get)
    active_langs = [lang for lang, score in scores.items() if score > 0]
    return dominant, active_langs or [dominant]
# ----------------------------
# PROMPT ENGINEERING
# ----------------------------
def build_prompt(user_input, dominant_lang, langs_used):
    """Assemble the tutor prompt sent to the language model.

    Args:
        user_input: The child's question or answer, inserted verbatim.
        dominant_lang: Language code selecting the instruction sentence
            ("fr", "sw" or "kin"); any other value uses the English one.
        langs_used: Languages detected in the input. When more than one is
            present, an extra instruction about mixed languages is appended.

    Returns:
        The prompt: instruction line, then ``User: ...`` and a trailing
        ``Assistant:`` cue.
    """
    instructions = {
        # The accented character was mis-encoded in the original literal.
        "fr": "Tu es un tuteur de mathématiques pour enfants. Explique simplement.",
        "sw": "Wewe ni mwalimu wa hesabu kwa watoto. Eleza kwa urahisi.",
        "kin": "Uri umwarimu w'imibare ku bana. Sobanura neza.",
    }
    base = instructions.get(
        dominant_lang,
        "You are a friendly math tutor for kids. Explain step by step.",
    )
    # Handle code-switching
    if len(langs_used) > 1:
        base += (
            " The child used mixed languages. Keep explanation in main"
            " language but reuse number words from other language."
        )
    return f"{base}\nUser: {user_input}\nAssistant:"
# ----------------------------
# GENERATION
# ----------------------------
def generate(prompt):
    """Sample a completion for *prompt* from the module-level model.

    Only the newly generated tokens are decoded. The sequence passed to
    ``decode`` previously started with the prompt tokens, so the
    instruction and "User:" lines were echoed back as part of the answer.

    Args:
        prompt: Full prompt text, e.g. as produced by ``build_prompt``.

    Returns:
        The generated continuation with special tokens removed and
        surrounding whitespace stripped. May be an empty string if the
        model produces no text.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[1]
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=80,
            temperature=0.7,
            do_sample=True,
        )
    # Drop the leading prompt tokens and keep just the completion.
    completion_ids = output[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
# ----------------------------
# TTS (MULTILINGUAL)
# ----------------------------
def speak(text, lang="en"):
    """Synthesize *text* with gTTS and save it to a temporary MP3 file.

    Args:
        text: The text to speak.
        lang: App language code. "en", "fr" and "sw" are passed to gTTS
            as-is; every other code (including "kin") uses English.

    Returns:
        Path of the saved ``.mp3`` file. The file is created with
        ``delete=False`` and is not removed here.
    """
    voice = lang if lang in {"en", "fr", "sw"} else "en"  # "kin" -> fallback
    synthesizer = gTTS(text=text, lang=voice)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as mp3_file:
        mp3_path = mp3_file.name
        synthesizer.save(mp3_path)
    return mp3_path
# ----------------------------
# VISUAL COUNTING (Baseline)
# ----------------------------
def count_objects(image):
    """Estimate an object count from the number of bright pixels.

    The image is converted with ``convert("L")``, pixels above 128 are
    counted, and each 400 such pixels count as one object.

    Args:
        image: An object exposing ``convert("L")``, such as a PIL image.

    Returns:
        The estimate as an int, never less than 1.
    """
    gray = np.array(image.convert("L"))
    bright_pixels = int(np.count_nonzero(gray > 128))
    estimate = bright_pixels // 400
    return estimate if estimate > 1 else 1
# ----------------------------
# UI
# ----------------------------
# Page setup: wide layout split into two equal columns
# (student interaction on the left, dashboard on the right).
st.set_page_config(layout="wide")
# Title emoji restored from mis-encoded bytes; the exact globe variant is a best guess.
st.title("🧠🌍 Multilingual AI Math Tutor")
col1, col2 = st.columns(2)
# ----------------------------
# LEFT PANEL
# ----------------------------
with col1:
    # Student-facing panel: one of three input modes, each ending in a
    # written and/or spoken reply. Emoji in the labels were restored from
    # mis-encoded bytes (the globe variant is a best guess).
    st.header("👧 Student Interaction")
    mode = st.radio("Mode", ["Text", "Voice", "Image"])

    # -------- TEXT --------
    if mode == "Text":
        user_input = st.text_input("Ask or answer:")
        if user_input:
            dominant, langs = detect_mixed_language(user_input)
            # Clicking "Speak" reruns the whole script. Reuse the stored
            # answer for an unchanged question so the sampled text shown
            # on screen is the same text that gets spoken.
            cached = st.session_state.get("text_answer")
            if cached is None or cached[0] != user_input:
                prompt = build_prompt(user_input, dominant, langs)
                cached = (user_input, generate(prompt))
                st.session_state["text_answer"] = cached
            response = cached[1]
            st.write("### 📘 Answer")
            st.write(response)
            st.write(f"🌍 Dominant: {dominant} | Mixed: {langs}")
            # Skip speech synthesis when there is nothing to say.
            if st.button("🔊 Speak") and response.strip():
                audio = speak(response, dominant)
                st.audio(audio)

    # -------- VOICE --------
    elif mode == "Voice":
        audio_file = st.file_uploader("Upload voice (.wav)", type=["wav", "mp3"])
        if audio_file:
            # The ASR pipeline takes a filename, raw bytes or an array,
            # not a file-like upload object, so hand it the bytes.
            result = asr(audio_file.getvalue())
            text = result["text"]
            st.write(f"🗣️ Detected: {text}")
            dominant, langs = detect_mixed_language(text)
            prompt = build_prompt(text, dominant, langs)
            response = generate(prompt)
            st.write("### 🎧 Response")
            st.write(response)
            # Skip speech synthesis when there is nothing to say.
            if response.strip():
                audio = speak(response, dominant)
                st.audio(audio)

    # -------- IMAGE --------
    elif mode == "Image":
        uploaded = st.file_uploader("Upload image", type=["png", "jpg"])
        if uploaded:
            image = Image.open(uploaded)
            st.image(image)
            count = count_objects(image)
            st.write(f"### 🧮 I see about {count} objects")
            explanation = f"There are {count} objects. Let's count together."
            # No language argument: speak() uses its default, "en".
            audio = speak(explanation)
            st.audio(audio)
# ----------------------------
# RIGHT PANEL (DASHBOARD)
# ----------------------------
with col2:
    # Dashboard panel. Every value below is a literal in this file; nothing
    # is computed from the student's activity. Emoji and check marks were
    # restored from mis-encoded bytes (the globe variant is a best guess).
    st.header("📊 Learning Dashboard")
    st.metric("Questions", 15)
    st.metric("Accuracy", "80%")
    st.metric("Level", "Improving")

    st.subheader("📈 Skill Progress")
    st.progress(0.8)

    st.subheader("🌍 Language System")
    st.write("✔ English / French / Swahili / Kinyarwanda")
    st.write("✔ Code-switch detection")

    st.subheader("⚡ Features")
    st.write("✔ Voice (Whisper)")
    st.write("✔ Visual counting")
    st.write("✔ Multimodal learning")