|
|
import io
import os
import random
import re
import warnings

import numpy as np
import torch

# Fix every RNG the app may use so repeated runs are reproducible. Seeding
# happens before the heavier third-party imports below.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

import streamlit as st
from PIL import Image
from transformers import logging  # NOTE: shadows the stdlib name `logging`
from SkinGPT import SkinGPTClassifier
from fpdf import FPDF
import nest_asyncio

# Patch asyncio so an already-running event loop can be re-entered.
nest_asyncio.apply()

torch.set_default_dtype(torch.float32)
# Half precision when a GPU is available, full precision on CPU.
MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# Silence all Python warnings; a single blanket "ignore" filter subsumes any
# category-specific (FutureWarning / UserWarning) filters.
warnings.filterwarnings("ignore")
# transformers' own logger: report errors only.
logging.set_verbosity_error()

# Fail fast when the Hugging Face token is missing from the environment.
token = os.getenv("HF_TOKEN")
if not token:
    raise ValueError("Hugging Face token not found in environment variables")
|
|
|
|
|
|
|
|
def remove_code_blocks(text):
    """Strip Markdown code from *text* before it is rendered.

    Two kinds of code are removed, in this order:

    1. Fenced blocks delimited by triple backticks. The match is non-greedy,
       so each opening fence pairs with the nearest closing fence, and it may
       span several lines. An opening fence with no closing fence is kept.
    2. Runs of consecutive lines starting with four or more spaces (Markdown
       indented-code syntax), together with their trailing newline if any.

    Returns the remaining text.
    """
    fence_free = re.sub(r"```[\s\S]*?```", "", text)
    indented_code = re.compile(r"^( {4,}.*\n?)+", re.MULTILINE)
    return indented_code.sub("", fence_free)
|
|
|
|
|
# Preferred torch device string for this process.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Page metadata, set before any other st.* element is rendered.
st.set_page_config(
    page_title="SkinGPT",
    page_icon="🧬",
    layout="centered",
)
|
|
|
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def get_classifier():
    """Build the SkinGPT classifier and freeze it for inference.

    The ViT, Q-Former and LLaMA sub-modules are each switched to eval mode,
    and all of their parameters get ``requires_grad = False``. The result is
    wrapped in ``st.cache_resource``, so Streamlit reuses one instance rather
    than rebuilding it on every rerun.

    Returns:
        The ready-to-use ``SkinGPTClassifier`` instance.
    """
    skin_model = SkinGPTClassifier()
    frozen_parts = (
        skin_model.model.vit,
        skin_model.model.q_former,
        skin_model.model.llama,
    )
    for part in frozen_parts:
        part.eval()
        for weight in part.parameters():
            weight.requires_grad = False
    return skin_model
|
|
|
|
|
# Keep a handle to the (cached) classifier in this session's state so it is
# looked up via get_classifier() only once per session.
if "app_models" not in st.session_state:
    st.session_state.app_models = get_classifier()
classifier = st.session_state.app_models

# Per-session defaults: the chat transcript and the image being discussed.
for _state_key, _make_default in (
    ("messages", list),
    ("current_image", lambda: None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
|
|
|
|
|
|
|
|
# Typographic characters common in model output that Latin-1 lacks, mapped
# to ASCII look-alikes before the generic "?" fallback is applied.
_PDF_CHAR_MAP = str.maketrans({
    "\u2018": "'",
    "\u2019": "'",
    "\u201c": '"',
    "\u201d": '"',
    "\u2013": "-",
    "\u2014": "-",
    "\u2022": "*",
    "\u2026": "...",
})


def _to_latin1(text):
    """Return *text* restricted to Latin-1, replacing anything else with "?".

    The PDF below uses a core font (set_font("Arial")) and its output is
    treated as latin1, so characters outside Latin-1 (emoji, CJK, ...) would
    otherwise make the export raise.
    """
    return text.translate(_PDF_CHAR_MAP).encode("latin-1", "replace").decode("latin-1")


def export_chat_to_pdf(messages):
    """Render the chat transcript as a PDF and return it as an in-memory file.

    Args:
        messages: iterable of dicts with "role" and "content" keys. Role
            "user" is labelled "You"; any other role is labelled "AI".

    Returns:
        io.BytesIO containing the PDF, positioned at offset 0.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)

    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        pdf.multi_cell(0, 10, _to_latin1(f"{role}: {msg['content']}\n"))
        # w=0 means "extend to the right margin". Some fpdf2 releases may leave
        # the cursor at the right edge after multi_cell, which would make the
        # next full-width cell zero-width; reset x explicitly (no-op otherwise).
        pdf.set_x(pdf.l_margin)

    raw = pdf.output(dest="S")
    # PyFPDF 1.7 returns a str that must be encoded; fpdf2 returns a bytearray.
    pdf_bytes = raw.encode("latin1") if isinstance(raw, str) else bytes(raw)
    return io.BytesIO(pdf_bytes)
|
|
|
|
|
|
|
|
|
|
|
st.title("🧬 DermBOT — Skin AI Assistant")
st.caption("🧠 Using model: SkinGPT")

uploaded_file = st.file_uploader(
    "Upload a skin image",
    type=["jpg", "jpeg", "png"],
    key="file_uploader",
)

# A file different from the one stored in session state starts a fresh
# conversation: drop the history and any embeddings cached on the classifier.
is_new_upload = (
    uploaded_file is not None
    and uploaded_file != st.session_state.current_image
)
if is_new_upload:
    st.session_state.messages = []
    st.session_state.current_image = uploaded_file
    classifier.current_image_embeddings = None

# Show the image under discussion on every rerun, not only on the rerun that
# received the upload, so it stays visible while follow-ups are asked.
if st.session_state.current_image is not None:
    image = Image.open(st.session_state.current_image).convert("RGB")
    st.image(image, caption="Uploaded image", use_column_width=True)

    if is_new_upload:
        with st.spinner("Analyzing the image..."):
            result = classifier.predict(image, reuse_embeddings=False)
        print("result in app : ", result["diagnosis"])
        st.session_state.messages.append(
            {"role": "assistant", "content": result["diagnosis"]}
        )

# Replay the transcript. User text is rendered through remove_code_blocks,
# matching how a freshly submitted prompt is displayed below.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        shown = message["content"]
        if message["role"] == "user":
            shown = remove_code_blocks(shown)
        st.markdown(shown)

if prompt := st.chat_input("Ask a follow-up question..."):
    if st.session_state.current_image is None:
        # No analysed image yet: nothing to ask about, and Image.open(None)
        # below would raise.
        st.warning("Please upload a skin image before asking a question.")
    else:
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(remove_code_blocks(prompt))

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                image = Image.open(st.session_state.current_image).convert("RGB")
                history = st.session_state.messages[:-1]
                if history:
                    conversation_context = "\n".join(
                        f"{m['role']}: {m['content']}" for m in history
                    )
                    augmented_prompt = (
                        f"Conversation history:\n{conversation_context}\n\n"
                        f"Current question: {prompt}"
                    )
                    # NOTE(review): reuse_embeddings=True relies on embeddings
                    # stored on `classifier`, which comes from st.cache_resource
                    # and may be shared between sessions; concurrent users could
                    # overwrite each other's cached embeddings — confirm against
                    # SkinGPTClassifier.predict.
                    result = classifier.predict(
                        image, user_input=augmented_prompt, reuse_embeddings=True
                    )
                else:
                    result = classifier.predict(
                        image, user_input=prompt, reuse_embeddings=False
                    )
            st.markdown(result["diagnosis"])

        st.session_state.messages.append(
            {"role": "assistant", "content": result["diagnosis"]}
        )

# Two-step export: the first button builds the PDF, then a download button
# for it is offered.
if st.session_state.messages and st.button("📄 Download Chat as PDF"):
    pdf_file = export_chat_to_pdf(st.session_state.messages)
    st.download_button(
        "Download PDF",
        data=pdf_file,
        file_name="skingpt_chat_history.pdf",
        mime="application/pdf",
    )