File size: 5,118 Bytes
3c32556
2caff32
 
 
 
 
 
3c32556
 
 
af6df0b
264e606
2e04d58
 
 
51b26bc
 
6c7e865
 
af6df0b
 
3c32556
 
 
 
 
 
 
3b126a3
 
bea4518
3b126a3
 
d35e524
 
 
 
3c32556
2e04d58
 
 
 
 
b3795f6
2caff32
 
 
 
 
 
 
 
 
3c32556
2e04d58
 
 
 
3c32556
 
 
 
 
2e04d58
 
3c32556
 
 
 
 
 
 
 
 
 
1ddb49e
 
3c32556
 
 
 
 
 
 
2e04d58
 
 
 
 
 
 
 
 
7f3d8d1
2e04d58
3c32556
2e04d58
 
7f3d8d1
0e90055
2e04d58
 
eb5f347
732a6f0
 
8927e5b
 
732a6f0
 
 
 
bea4518
2e04d58
 
 
732a6f0
 
eb5f347
732a6f0
2e04d58
 
 
7f3d8d1
2e04d58
 
 
7f3d8d1
2e04d58
 
 
 
 
7f3d8d1
73f97d4
7f3d8d1
2e04d58
2ea890b
8927e5b
 
2e04d58
 
 
3c32556
2e04d58
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import torch
import random
import numpy as np
# Seed every RNG source up front so model outputs are reproducible across reruns.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
import streamlit as st
import io
from PIL import Image
import os
from transformers import logging
from SkinGPT import SkinGPTClassifier
from fpdf import FPDF
import nest_asyncio
# Allow nested event loops — async libraries otherwise clash with Streamlit's loop.
nest_asyncio.apply()
torch.set_default_dtype(torch.float32)  # Main computations in float32
# Half precision only when a GPU is available; CPU float16 kernels are slow/absent.
MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# Silence transformers' progress/warning logging.
logging.set_verbosity_error()
token = os.getenv("HF_TOKEN")
if not token:
    raise ValueError("Hugging Face token not found in environment variables")
# NOTE(review): duplicate `import warnings` — already imported above; the blanket
# filter below also supersedes the two category-specific filters.
import warnings

warnings.filterwarnings("ignore")

import re

def remove_code_blocks(text):
    """Strip Markdown code blocks from *text*.

    Removes fenced triple-backtick blocks first, then any run of lines
    indented by four or more spaces (Markdown's indented-code syntax).
    """
    without_fences = re.sub(r"```[\s\S]*?```", "", text)
    return re.sub(r"^( {4,}.*\n?)+", "", without_fences, flags=re.MULTILINE)

# Preferred compute device ('cuda' when a GPU is visible to torch).
# NOTE(review): `device` is not referenced anywhere else in this file — confirm
# whether SkinGPTClassifier handles its own device placement and this is dead.
device='cuda' if torch.cuda.is_available() else 'cpu'
st.set_page_config(page_title="SkinGPT", page_icon="🧬", layout="centered")


@st.cache_resource(show_spinner=False)
def get_classifier():
    """Build the SkinGPT classifier once per process and freeze it for inference.

    Puts the vision transformer, Q-Former and LLaMA submodules into eval mode
    and disables gradients on all of their parameters.
    """
    clf = SkinGPTClassifier()
    frozen = (clf.model.vit, clf.model.q_former, clf.model.llama)
    for submodule in frozen:
        submodule.eval()
        for weight in submodule.parameters():
            weight.requires_grad = False
    return clf

# One classifier per browser session; get_classifier() is additionally cached
# per process by @st.cache_resource, so construction happens at most once.
if 'app_models' not in st.session_state:
    st.session_state.app_models = get_classifier()

classifier = st.session_state.app_models

# === Session Init ===
# Chat transcript: list of {"role": ..., "content": ...} dicts.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Most recently uploaded image file; None until the user uploads one.
if "current_image" not in st.session_state:
    st.session_state.current_image = None

# === PDF Export ===
def export_chat_to_pdf(messages):
    """Render the chat transcript to a PDF and return it as an in-memory buffer.

    Args:
        messages: list of {"role": ..., "content": ...} dicts; a role of
            "user" is labeled "You", anything else "AI".

    Returns:
        io.BytesIO positioned at the start of the PDF bytes.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        # FPDF's built-in fonts only cover Latin-1; replace unsupported chars
        # (e.g. emoji in model output) instead of crashing on encode.
        line = f"{role}: {msg['content']}\n"
        safe_line = line.encode("latin-1", "replace").decode("latin-1")
        pdf.multi_cell(0, 10, safe_line)
    raw = pdf.output(dest='S')
    # PyFPDF returns a str that must be latin-1 encoded; fpdf2 returns a
    # bytearray with no .encode() — handle both instead of crashing on fpdf2.
    pdf_bytes = raw.encode('latin1') if isinstance(raw, str) else bytes(raw)
    buf = io.BytesIO(pdf_bytes)
    buf.seek(0)
    return buf

# === App UI ===

st.title("🧬 DermBOT — Skin AI Assistant")
st.caption(f"🧠 Using model: SkinGPT")
uploaded_file = st.file_uploader(
    "Upload a skin image",
    type=["jpg", "jpeg", "png"],
    key="file_uploader"
)

# A new upload resets the conversation and the classifier's cached embeddings.
# NOTE(review): relies on UploadedFile inequality to detect "new" files —
# confirm this does not re-trigger analysis on unrelated reruns.
if uploaded_file is not None and uploaded_file != st.session_state.current_image:
    st.session_state.messages = []
    st.session_state.current_image = uploaded_file
    classifier.current_image_embeddings = None

    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded image", use_column_width=True)
    with st.spinner("Analyzing the image..."):
        # Fresh image: embeddings from any previous upload must not be reused.
        result = classifier.predict(image, reuse_embeddings=False)
        print("result in app : ", result["diagnosis"])
    # Seed the chat with the model's initial diagnosis.
    st.session_state.messages.append({"role": "assistant", "content": result["diagnosis"]})

# Replay the transcript on every rerun so the chat history stays visible.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        # st.markdown(remove_code_blocks(message["content"]))
        st.markdown(message["content"])
        # st.text(message["content"])

# for message in st.session_state.messages:
#     role = "You" if message["role"] == "user" else "assistant"
#     st.markdown(f"**{role}:** {message['content']}")

# === Chat Interface ===
if prompt := st.chat_input("Ask a follow-up question..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(remove_code_blocks(prompt))

    with st.chat_message("assistant"):
        if st.session_state.current_image is None:
            # Bug fix: Image.open(None) previously crashed when a question was
            # asked before any image had been uploaded.
            no_image_msg = "Please upload a skin image first so I can answer questions about it."
            st.markdown(no_image_msg)
            st.session_state.messages.append({"role": "assistant", "content": no_image_msg})
        else:
            with st.spinner("Thinking..."):
                image = Image.open(st.session_state.current_image).convert("RGB")
                if len(st.session_state.messages) > 1:
                    # Fold earlier turns into the prompt so the model sees the
                    # conversation context; reuse cached image embeddings.
                    conversation_context = "\n".join(
                        f"{m['role']}: {m['content']}"
                        for m in st.session_state.messages[:-1]
                    )
                    augmented_prompt = (
                        f"Conversation history:\n{conversation_context}\n\n"
                        f"Current question: {prompt}"
                    )
                    result = classifier.predict(image, user_input=augmented_prompt, reuse_embeddings=True)
                else:
                    # First question about this image: compute embeddings fresh.
                    result = classifier.predict(image, user_input=prompt, reuse_embeddings=False)

                st.markdown(result["diagnosis"])
                st.session_state.messages.append({"role": "assistant", "content": result["diagnosis"]})

# Offer a PDF export of the transcript once there is something to export.
if st.session_state.messages:
    if st.button("📄 Download Chat as PDF"):
        chat_pdf = export_chat_to_pdf(st.session_state.messages)
        st.download_button(
            "Download PDF",
            data=chat_pdf,
            file_name="skingpt_chat_history.pdf",
            mime="application/pdf",
        )