# SkinGPT / app.py
# Author: KeerthiVM — last change: "Prompt change" (commit 8927e5b)
import torch
import random
import numpy as np

# Seed every RNG source before anything else runs so repeated sessions
# produce reproducible model outputs.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

import streamlit as st
import io
from PIL import Image
import os
from transformers import logging
from SkinGPT import SkinGPTClassifier
from fpdf import FPDF
import nest_asyncio

# Streamlit runs its own event loop; nest_asyncio lets nested asyncio code
# (e.g. inside model/tokenizer utilities) run without "loop already running" errors.
nest_asyncio.apply()

torch.set_default_dtype(torch.float32)  # Main computations in float32
# Half precision only when a GPU is present; CPU inference stays in float32.
MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
logging.set_verbosity_error()

# A Hugging Face token is mandatory (presumably for gated model downloads
# inside SkinGPTClassifier — TODO confirm); fail fast if it is absent.
token = os.getenv("HF_TOKEN")
if not token:
    raise ValueError("Hugging Face token not found in environment variables")

# NOTE(review): duplicate of the `import warnings` above; this blanket filter
# supersedes the category-specific ones. Left as-is to avoid behavior change.
import warnings
warnings.filterwarnings("ignore")
import re
def remove_code_blocks(text):
    """Strip Markdown code blocks from *text*.

    Removes triple-backtick fenced blocks first, then any runs of lines
    indented by 4+ spaces (Markdown's indented-code convention).
    """
    without_fences = re.sub(r"```[\s\S]*?```", "", text)
    return re.sub(r"^( {4,}.*\n?)+", "", without_fences, flags=re.MULTILINE)
# Compute device for inference; CUDA when available, otherwise CPU.
device='cuda' if torch.cuda.is_available() else 'cpu'
# Must be the first Streamlit call in the script.
st.set_page_config(page_title="SkinGPT", page_icon="🧬", layout="centered")
@st.cache_resource(show_spinner=False)
def get_classifier():
    """Load the SkinGPT classifier once per process and freeze it for inference.

    Cached by Streamlit so reruns and concurrent sessions share one model.
    """
    clf = SkinGPTClassifier()
    model = clf.model
    # Switch every sub-network to eval mode and detach it from autograd;
    # requires_grad_(False) flips requires_grad on all parameters in one call.
    for component in (model.vit, model.q_former, model.llama):
        component.eval()
        component.requires_grad_(False)
    return clf
# Keep the (expensive) classifier in session state so each rerun of the
# script reuses the already-loaded model.
if 'app_models' not in st.session_state:
    st.session_state.app_models = get_classifier()
classifier = st.session_state.app_models

# === Session Init ===
# messages: running chat transcript as {"role", "content"} dicts.
if "messages" not in st.session_state:
    st.session_state.messages = []
# current_image: the last uploaded file, used to detect new uploads.
if "current_image" not in st.session_state:
    st.session_state.current_image = None
# === PDF Export ===
def export_chat_to_pdf(messages):
    """Render the chat transcript as a PDF and return it as an in-memory buffer.

    Args:
        messages: list of {"role": ..., "content": ...} chat dicts.

    Returns:
        io.BytesIO positioned at offset 0, containing the PDF bytes.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        # Classic FPDF core fonts are Latin-1 only; replace unsupported
        # characters (e.g. emoji in model output) instead of letting
        # multi_cell raise UnicodeEncodeError and kill the export.
        line = f"{role}: {msg['content']}\n"
        safe_line = line.encode("latin-1", "replace").decode("latin-1")
        pdf.multi_cell(0, 10, safe_line)
    buf = io.BytesIO()
    # fpdf 1.x returns a Latin-1 str from output(dest='S'); fpdf2 returns a
    # bytearray. Handle both so the app works with either library version.
    raw = pdf.output(dest='S')
    pdf_bytes = raw.encode('latin-1') if isinstance(raw, str) else bytes(raw)
    buf.write(pdf_bytes)
    buf.seek(0)
    return buf
# === App UI ===
st.title("🧬 DermBOT — Skin AI Assistant")
st.caption(f"🧠 Using model: SkinGPT")

uploaded_file = st.file_uploader(
    "Upload a skin image",
    type=["jpg", "jpeg", "png"],
    key="file_uploader"
)

# A new upload resets the conversation and any cached image embeddings,
# then runs an initial diagnosis.
# NOTE(review): relies on UploadedFile `!=` comparison being stable across
# reruns to detect "new file" — confirm this doesn't re-trigger on rerun.
if uploaded_file is not None and uploaded_file != st.session_state.current_image:
    st.session_state.messages = []
    st.session_state.current_image = uploaded_file
    classifier.current_image_embeddings = None
    image = Image.open(uploaded_file).convert("RGB")
    # NOTE(review): use_column_width is deprecated in newer Streamlit
    # (use_container_width) — confirm the pinned version before changing.
    st.image(image, caption="Uploaded image", use_column_width=True)
    with st.spinner("Analyzing the image..."):
        result = classifier.predict(image, reuse_embeddings=False)
    print("result in app : ", result["diagnosis"])
    st.session_state.messages.append({"role": "assistant", "content": result["diagnosis"]})

# Replay the whole transcript on every Streamlit rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        # st.markdown(remove_code_blocks(message["content"]))
        st.markdown(message["content"])
        # st.text(message["content"])

# for message in st.session_state.messages:
#     role = "You" if message["role"] == "user" else "assistant"
#     st.markdown(f"**{role}:** {message['content']}")
# === Chat Interface ===
if prompt := st.chat_input("Ask a follow-up question..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(remove_code_blocks(prompt))
        # st.markdown(f"**You:** {prompt}")
    with st.chat_message("assistant"):
        if st.session_state.current_image is None:
            # Guard: without an uploaded image, Image.open(None) would raise.
            # Reply with an instruction instead of crashing the session.
            reply = "Please upload a skin image first so I can answer questions about it."
            st.markdown(reply)
            st.session_state.messages.append({"role": "assistant", "content": reply})
        else:
            with st.spinner("Thinking..."):
                image = Image.open(st.session_state.current_image).convert("RGB")
                if len(st.session_state.messages) > 1:
                    # Fold all prior turns into the prompt so the model sees
                    # conversational context; embeddings for the image are reused.
                    conversation_context = "\n".join(
                        f"{m['role']}: {m['content']}"
                        for m in st.session_state.messages[:-1]
                    )
                    augmented_prompt = (
                        f"Conversation history:\n{conversation_context}\n\n"
                        f"Current question: {prompt}"
                    )
                    result = classifier.predict(image, user_input=augmented_prompt, reuse_embeddings=True)
                else:
                    # First question for this image: compute embeddings fresh.
                    result = classifier.predict(image, user_input=prompt, reuse_embeddings=False)
            # st.markdown(remove_code_blocks(result["diagnosis"]))
            st.markdown(result["diagnosis"])
            # st.text(result["diagnosis"])
            st.session_state.messages.append({"role": "assistant", "content": result["diagnosis"]})
# Offer the transcript as a PDF once there is at least one message.
if st.session_state.messages and st.button("📄 Download Chat as PDF"):
    pdf_file = export_chat_to_pdf(st.session_state.messages)
    st.download_button(
        "Download PDF",
        data=pdf_file,
        file_name="skingpt_chat_history.pdf",
        mime="application/pdf"
    )