# Author: dabbitz
# Commit e474ce6 — "Update app.py" (verified)
import streamlit as st
import torch
import numpy as np
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModelForSeq2SeqLM, # For translation
pipeline
)
# --- CONFIGURATION ---
# Directory holding the fine-tuned classifier, its tokenizer files and
# symptom_list.txt.
MODEL_DIR = "symptom_classifier"

# --- SYMPTOM TRANSLATION DICTIONARY ---
# Maps the classifier's English label names to Turkish display names for the
# UI. Labels missing from this table are shown in English by the caller.
SYMPTOM_TR_MAP = {
    "headache": "Baş ağrısı",
    "dizziness": "Baş dönmesi",
    "fatigue": "Yorgunluk",
    "blurred_vision": "Bulanık görme",
    "runny_nose": "Burun akıntısı",
    "blocked_nose": "Burun tıkanıklığı",
    "itchy_nose": "Burun kaşıntısı",
    "shortness_of_breath": "Nefes darlığı",
    "coughing": "Öksürük",
    "sneezing": "Hapşırma",
    "pressure_in_ears": "Kulaklarda basınç",
    "pressure_in_face": "Yüzde basınç",
    "low_fever": "Hafif ateş",
    "high_fever": "Yüksek ateş",
    "diarrhea": "İshal",
    "nausea": "Bulantı",
    "vomiting": "Kusma",
    "stomach_cramps": "Karın krampları",
    "stomach_pain": "Karın ağrısı",
    "occasional_muscle_ache": "Ara sıra oluşan kas ağrısı",
    # "sore_throat" (irritation) and "throat_pain" are distinct labels.
    "sore_throat": "Boğaz tahrişi",
    "chest_tightness": "Göğüs sıkışması",
    "itchy_eyes": "Göz kaşıntısı",
    "watery_eyes": "Göz sulanması",
    "rashes": "Döküntüler",
    "swelling_of_skin": "Ciltte şişme",
    "redness_of_skin": "Ciltte kızarıklık",
    "throat_pain": "Boğaz ağrısı",
    "itchy_skin": "Cilt kaşıntısı",
    "dry_skin": "Cilt kuruluğu",
    "blisters": "Kabarcıklar",
    "dandruff": "Kepek",
    "thickened_skin": "Cilt kalınlaşması",
    "raised_bumps_on_skin": "Ciltte kabarıklıklar",
    "blotchy_skin": "Lekeli cilt",
    "redness_in_eyes": "Gözlerde kızarıklık",
    "sensitivity_to_light": "Işığa duyarlılık",
    "swelling_eyes": "Göz şişmesi",
    "double_vision": "Çift görme",
    "seeing_faded_colors": "Renkleri soluk görme",
    "difficulty_seeing_in_the_dark": "Karanlıkta görme zorluğu",
    "constant_weight_loss": "Sürekli kilo kaybı",
    "pain_or_numbness_at_hands_and_feet": "El ve ayaklarda ağrı veya uyuşma",
    "frequent_urination": "Sık idrara çıkma",
    "slow_healing_of_cuts_and_bruises": "Kesik ve morlukların geç iyileşmesi",
    "constant_hunger": "Sürekli açlık",
    "burning_skin": "Ciltte yanma hissi",
    "constipation": "Kabızlık",
    "low_back_pain": "Bel ağrısı",
    "ear_pain": "Kulak ağrısı",
    "hearing_loss": "İşitme kaybı",
    "melancholy": "Melankoli / Karamsarlık",
    "loss_of_appetite": "İştah kaybı",
    "lack_of_sleep": "Uykusuzluk",
    "palpitation": "Çarpıntı",
    "pale_skin": "Solgun cilt",
    "pain_during_urination": "İdrar yaparken ağrı",
    "groin_pain": "Kasık ağrısı",
    "no_symptom": "Belirgin semptom yok",
}
@st.cache_resource
def load_models():
    """Load the classifier, speech-to-text and translation models once.

    ``st.cache_resource`` keeps the returned objects alive across Streamlit
    reruns, so model weights are loaded only on the first run.

    Returns:
        tuple: ``(tokenizer, model, symptoms, device, stt_pipe, tr_tokenizer,
        tr_model)`` where

        - ``tokenizer`` / ``model`` are the English symptom classifier loaded
          from ``MODEL_DIR``;
        - ``symptoms`` holds the label names read from
          ``MODEL_DIR/symptom_list.txt``, one per non-blank line (assumed to be
          in the same order as the classifier's output logits — TODO confirm
          against the training script);
        - ``device`` is CUDA when available, otherwise CPU;
        - ``stt_pipe`` is the ``openai/whisper-tiny`` speech-recognition
          pipeline;
        - ``tr_tokenizer`` / ``tr_model`` are the ``Helsinki-NLP/opus-mt-tr-en``
          tokenizer and seq2seq translation model.

    Raises:
        OSError: if ``symptom_list.txt`` cannot be opened. Errors raised by
        ``transformers`` while loading models propagate unchanged.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. English symptom classifier and its label names.
    # Read as UTF-8 explicitly (the platform default encoding varies). Skip
    # blank lines so an empty line cannot shift later labels out of step with
    # the classifier's outputs.
    with open(f"{MODEL_DIR}/symptom_list.txt", "r", encoding="utf-8") as f:
        symptoms = [name for name in (line.strip() for line in f) if name]
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)

    # 2. Whisper speech-to-text, placed on the same device as the other models.
    stt_pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-tiny",
        device=device,
    )

    # 3. Translation model and tokenizer, used directly instead of via a pipeline.
    tr_model_name = "Helsinki-NLP/opus-mt-tr-en"
    tr_tokenizer = AutoTokenizer.from_pretrained(tr_model_name)
    tr_model = AutoModelForSeq2SeqLM.from_pretrained(tr_model_name)

    model.to(device)
    tr_model.to(device)
    # Both models are only used for inference, so switch them to eval mode.
    model.eval()
    tr_model.eval()
    return tokenizer, model, symptoms, device, stt_pipe, tr_tokenizer, tr_model
# Load every model the UI needs. After the first run, the objects come from
# Streamlit's resource cache.
(
    tokenizer,
    model,
    symptoms,
    device,
    stt_pipe,
    tr_tokenizer,
    tr_model,
) = load_models()

# --- PAGE HEADER ---
st.title("AI Symptom Classifier (TR/EN)")
st.markdown(
    "Semptomlarınızı **Türkçe** veya **İngilizce** açıklayınız. "
    "Girdiniz otomatik olarak işlenecektir."
)
# --- INPUT SECTION ---
@st.cache_data(show_spinner=False)
def _transcribe(audio_bytes):
    """Return Whisper's transcription of *audio_bytes*, with surrounding whitespace stripped.

    Streamlit re-executes the whole script on every interaction (the Analyze
    button, the threshold slider), and ``st.audio_input`` keeps returning the
    same recording. Caching on the raw bytes means Whisper runs once per
    distinct recording instead of on every rerun.
    """
    # Whisper automatically detects Turkish or English
    return stt_pipe(audio_bytes)["text"].strip()


tab1, tab2 = st.tabs(["Metin Girdisi (Text)", "Ses Girdisi (Voice)"])
# Both tabs are evaluated on every run. The voice tab runs second, so when a
# recording exists its transcription overrides any typed text.
input_text = ""
with tab1:
    text_msg = st.text_area(
        "Semptomlarınızı açıklayın / Describe symptoms:",
        placeholder="Örn: Başım ağrıyor ve ateşim var...",
    )
    # Treat whitespace-only input as empty so it never reaches the models.
    cleaned_text = (text_msg or "").strip()
    if cleaned_text:
        input_text = cleaned_text
with tab2:
    audio_file = st.audio_input("Sesinizi kaydedin / Record voice")
    if audio_file:
        with st.spinner("Ses çözülüyor..."):
            # getvalue() returns the whole recording regardless of the
            # buffer's current read position (read() would not).
            input_text = _transcribe(audio_file.getvalue())
        st.info(f"Algılanan Metin: \"{input_text}\"")
# --- ANALYSIS SECTION ---
def _translate_to_english(text):
    """Translate *text* with the opus-mt-tr-en model and return the decoded string.

    ``truncation=True`` clips inputs longer than the translation model's
    maximum length. Without it, ``generate`` fails on very long descriptions.
    """
    # NOTE(review): English input is also sent through the TR->EN model; the
    # output may differ slightly from the original wording.
    batch = tr_tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        generated = tr_model.generate(**batch)
    return tr_tokenizer.decode(generated[0], skip_special_tokens=True)


def _score_symptoms(text):
    """Return ``(symptom, probability)`` pairs for *text*, highest probability first."""
    encoding = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    encoding = {k: v.to(device) for k, v in encoding.items()}
    with torch.no_grad():
        logits = model(**encoding).logits
    # Independent sigmoid per label (multi-label scores, not a softmax).
    probs = torch.sigmoid(logits).cpu().numpy()[0]
    results = {symptoms[i]: float(p) for i, p in enumerate(probs)}
    return sorted(results.items(), key=lambda x: x[1], reverse=True)


# Default threshold set to 0.70
threshold = st.sidebar.slider("Sensitivity Threshold", 0.0, 1.0, 0.70)
if input_text:
    if st.button("Semptomları Analiz Et / Analyze"):
        with st.spinner("Processing..."):
            # 1. Translation (direct model usage)
            translated_text = _translate_to_english(input_text)
            st.subheader("Translation")
            st.info(f"**English version:** {translated_text}")

            # 2. Inference on the translated text
            sorted_results = _score_symptoms(translated_text)

            # 3. Results
            col1, col2 = st.columns(2)
            with col1:
                st.subheader("Tespit Edilen Semptomlar")
                detected = [s for s, p in sorted_results if p > threshold]
                if not detected:
                    st.warning("Belirgin bir semptom saptanmadı.")
                else:
                    for d in detected:
                        # Turkish display name, falling back to the English label.
                        tr_name = SYMPTOM_TR_MAP.get(d, d)
                        st.success(f"✅ **{tr_name}**")
            with col2:
                st.subheader("Güven Skorları")
                # Show the 8 highest-scoring labels regardless of the threshold.
                for sym, p in sorted_results[:8]:
                    tr_name = SYMPTOM_TR_MAP.get(sym, sym)
                    st.write(f"**{tr_name}** ({sym})")
                    st.progress(p)
                    st.caption(f"%{p*100:.1f} güven oranı")