jrrade's picture
added necessary files for app to run on HF
74c02ff
# --- START OF FILE app_main.py ---
import streamlit as st
import torch
import torch.nn.functional as F
from torchvision import transforms
# Ensure classification_model is importable (e.g., it's in the same directory or on PYTHONPATH)
try:
from classification_model.model import VGG16Model
except ImportError:
st.error("Could not import VGG16Model. Make sure 'classification_model/model.py' exists and is importable.")
st.stop()
from PIL import Image
import os
from dotenv import load_dotenv
from langchain.schema import SystemMessage, HumanMessage, AIMessage, BaseMessage
import io
# Import graph logic from the other file
try:
from graph_logic import initialize_llm, create_chat_graph, GraphState
except ImportError:
st.error("Could not import from graph_logic.py. Make sure the file exists in the same directory.")
st.stop()
# Load environment variables from .env file
load_dotenv()
# --- Helper Functions (Specific to App/UI) ---
@st.cache_resource
def load_vgg16(path='./classification_model/best_model.pth', device='cpu'):
    """Build a ``VGG16Model`` and load its weights from *path*.

    The function is wrapped in ``st.cache_resource``. Any failure is reported
    in the Streamlit UI with ``st.error`` and followed by ``st.stop()``.

    Args:
        path: Location of the saved state dict.
        device: Passed to ``torch.load`` as ``map_location``. The model itself
            is not moved to this device by this function.

    Returns:
        The model with the state dict loaded, switched to eval mode.
    """
    net = VGG16Model()

    checkpoint_missing = not os.path.exists(path)
    if checkpoint_missing:
        st.error(f"Model file not found at {path}. Please ensure the model is present.")
        st.stop()

    try:
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        weights = torch.load(path, map_location=device)
        net.load_state_dict(weights)
        net.eval()
    except Exception as exc:
        st.error(f"Error loading model state dict: {exc}")
        st.stop()

    return net
@st.cache_resource
def load_class_labels() -> dict:
    """Return the mapping from integer class index to class label name."""
    # Position in this tuple is the class index used as the dict key.
    label_names = (
        'good_images',
        'Imaging Artifact',
        'Not Tracking',
        'Tip Contamination',
    )
    return dict(enumerate(label_names))
def preprocess_image(img: Image.Image):
    """Turn a PIL image into a normalized, batched tensor for the classifier.

    The image is converted to RGB if necessary, resized to 224x224, converted
    to a tensor, normalized per channel, and given a leading batch dimension
    of size 1.

    Args:
        img: The input PIL image, in any mode that PIL can convert to RGB.

    Returns:
        A tensor whose first dimension is the batch dimension (size 1).
    """
    # Normalize below is configured with three channel statistics, so
    # grayscale, palette or RGBA inputs would fail with a channel mismatch.
    # Converting here keeps the helper safe regardless of what the caller did.
    if img.mode != "RGB":
        img = img.convert("RGB")

    pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # NOTE(review): these look like per-channel dataset statistics from
        # training - confirm before changing them.
        transforms.Normalize(mean=[0.3718, 0.1738, 0.0571], std=[0.2095, 0.2124, 0.1321]),
    ])
    return pipeline(img).unsqueeze(0)
def predict_image_class(img: Image.Image, model, class_names: dict) -> str:
    """Classify *img* with *model* and return the name of the top-scoring class.

    Args:
        img: PIL image to classify; it is passed through ``preprocess_image``.
        model: Callable that takes the preprocessed batch and returns
            per-class scores along dimension 1.
        class_names: Mapping from integer class index to label name.

    Returns:
        The label for the highest-probability index, or "Unknown Class" when
        that index is not a key of *class_names*.
    """
    batch = preprocess_image(img)
    with torch.no_grad():
        scores = F.softmax(model(batch), dim=1)
        # Only the index of the best class is needed, not its probability.
        best_index = torch.topk(scores, 1).indices
    return class_names.get(best_index.item(), "Unknown Class")
def download_chat_history() -> str:
    """Build the plain-text chat transcript from ``st.session_state.messages``.

    A ``SystemMessage`` in the first position is left out. Every remaining
    message becomes one ``"<role>: <content>"`` line, where the role is
    "User" for ``HumanMessage`` instances and "Assistant" for anything else.

    Returns:
        The transcript, or an empty string when there is no message history.
    """
    history = st.session_state.get("messages")
    if not history:
        return ""

    # Skip the system prompt only when it sits in the first slot.
    visible = history[1:] if isinstance(history[0], SystemMessage) else history

    transcript_lines = []
    for entry in visible:
        speaker = "User" if isinstance(entry, HumanMessage) else "Assistant"
        transcript_lines.append(f"{speaker}: {entry.content}\n")
    return "".join(transcript_lines)
# --- Streamlit UI ---
st.set_page_config(page_title="AFM Defect Assistant (LangGraph)", page_icon="🔬")
st.title("🔬 AFM Image Defect Classification + LLM-based AFM Assistant")
st.write("Upload an AFM image, get a classification, and chat with an AI assistant about the result.")

# --- Sidebar Controls ---
st.sidebar.header("Settings")

# Per-provider settings: (default model, selectable models, API key env var).
provider_settings = {
    "OpenAI": ("gpt-4o", ["gpt-4o", "o3-mini"], "OPENAI_API_KEY"),
    "Anthropic": (
        "claude-3-5-sonnet-latest",
        ["claude-3-5-sonnet-latest", "claude-3-7-sonnet-latest"],
        "ANTHROPIC_API_KEY",
    ),
}

# Model Selection
provider = st.sidebar.selectbox("LLM Provider", list(provider_settings), key="provider_select")
api_key = None
api_key_name = ""
if provider not in provider_settings:
    st.sidebar.error("Invalid provider selected.")
    st.stop()
default_model, available_models, api_key_name = provider_settings[provider]
api_key = os.getenv(api_key_name)

# Warn early when the key is missing; it is only required once a chat starts.
if not api_key:
    st.sidebar.warning(f"{provider} API key not found. Please set the {api_key_name} environment variable.")

model_name = st.sidebar.selectbox(
    f"Choose {provider} Model",
    available_models,
    index=available_models.index(default_model),
    key="model_name_select",
)
temperature = st.sidebar.slider("LLM Temperature", 0.0, 1.0, 0.3, 0.05, key="temp_slider")

# Clear Chat Button: drop all per-session state, then rerun the script.
if st.sidebar.button("Start New Session", key="clear_chat_button"):
    for stale_key in ("messages", "current_label", "uploaded_file_state", "llm", "graph"):
        if stale_key in st.session_state:
            del st.session_state[stale_key]
    st.rerun()
# --- Main Page Logic ---
# Flow for an uploaded image: classify it (once per file), make sure an LLM
# and chat graph matching the sidebar settings exist, then render the chat UI
# and the transcript download button.
# File Uploader
uploaded_file = st.file_uploader("Upload an AFM image", type=["jpg", "jpeg", "png"], key="file_uploader")

# Manage state based on uploaded file
if uploaded_file is not None:
    new_file_id = uploaded_file.file_id
    # A different file id means a new upload: remember it and drop the
    # chat/LLM state that belonged to the previous image.
    if ("uploaded_file_state" not in st.session_state
            or st.session_state.uploaded_file_state["id"] != new_file_id):
        st.session_state.uploaded_file_state = {"id": new_file_id, "name": uploaded_file.name}
        for key in ("messages", "current_label", "llm", "graph"):
            if key in st.session_state:
                del st.session_state[key]

    # --- Image Processing and Classification ---
    try:
        img = Image.open(uploaded_file).convert("RGB")
        file_state = st.session_state.uploaded_file_state
        st.image(img, caption=f"Uploaded: {file_state['name']}", width=200)

        # The script reruns on every interaction (each chat message included).
        # Classify only the first time a file is seen and keep the label with
        # the file record, so later reruns skip the model forward pass.
        class_label = file_state.get("label")
        if class_label is None:
            model = load_vgg16()
            class_names = load_class_labels()
            with st.spinner("Classifying image..."):
                class_label = predict_image_class(img, model, class_names)
            st.session_state.uploaded_file_state = {**file_state, "label": class_label}
        st.success(f"**Predicted Class Label:** {class_label}")

        # --- LLM and Graph Initialization ---
        label_changed = ("current_label" not in st.session_state or
                         st.session_state.current_label != class_label)
        # A fresh conversation is needed when nothing is initialized yet or
        # the label changed.
        needs_new_chat = ("llm" not in st.session_state
                          or "graph" not in st.session_state
                          or label_changed)
        # Sidebar settings the current LLM should reflect. Without this check
        # a provider/model/temperature change would be ignored until the next
        # upload or session reset. "llm_settings" needs no explicit clearing:
        # whenever "llm" is removed, needs_new_chat forces a rebuild anyway.
        llm_settings = (provider, model_name, temperature)
        settings_changed = st.session_state.get("llm_settings") != llm_settings

        if needs_new_chat or settings_changed:
            if not api_key:
                st.error(f"Cannot proceed without {api_key_name}. Please set it in your environment variables.")
                st.stop()
            try:
                # Build both objects before storing either, so a failure
                # leaves the previous llm/graph pair untouched.
                new_llm = initialize_llm(provider, model_name, temperature, api_key)
                new_graph = create_chat_graph(new_llm)
                st.session_state.llm = new_llm
                st.session_state.graph = new_graph
                st.session_state.llm_settings = llm_settings

                if needs_new_chat:
                    st.session_state.current_label = class_label
                    # Define the system prompt and initial state for the graph
                    system_prompt_content = (
                        f"You are an expert in atomic force microscopy (AFM). "
                        f"The user has uploaded an image and it has '{class_label}' defect. "
                        "Your role is to help the user understand this defect, potential causes, "
                        "and how to potentially avoid or address the issue represented by this defect. "
                        "Provide concise, technically accurate, and helpful answers. Avoid speculation if unsure."
                    )
                    system_message = SystemMessage(content=system_prompt_content)
                    st.session_state.messages = [system_message]  # Initialize message history
                # On a settings-only change the existing conversation is kept
                # and simply continues with the newly configured LLM.
            except Exception as e:
                st.error(f"Failed to initialize LLM or Graph: {e}")
                st.stop()

        # --- Chat Interface ---
        st.divider()
        st.header(f"Chat about '{st.session_state.current_label}'")

        # Display existing messages (skip system message)
        if "messages" in st.session_state:
            for i, msg in enumerate(st.session_state.messages):
                if i == 0 and isinstance(msg, SystemMessage):
                    continue
                with st.chat_message("user" if isinstance(msg, HumanMessage) else "assistant"):
                    st.markdown(msg.content)

        # Chat input
        if prompt := st.chat_input("Ask a question about the detected defect..."):
            # Add user message to state and display it
            st.session_state.messages.append(HumanMessage(content=prompt))
            with st.chat_message("user"):
                st.markdown(prompt)

            # Prepare the input state for the graph
            current_graph_state: GraphState = {"messages": st.session_state.messages}

            # Invoke the graph
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        response_state = st.session_state.graph.invoke(current_graph_state)
                        # Update session state with the message history returned by the graph
                        st.session_state.messages = response_state['messages']
                        ai_response_content = st.session_state.messages[-1].content
                        st.markdown(ai_response_content)
                    except Exception as e:
                        st.error(f"Error during chat generation: {e}")
                        # Roll back the user message so a failed turn does not
                        # stay in the history.
                        if st.session_state.messages and isinstance(st.session_state.messages[-1], HumanMessage):
                            st.session_state.messages.pop()

        # --- Download Chat Button ---
        if len(st.session_state.get("messages", [])) > 1:  # Show only if conversation started
            st.divider()
            chat_text = download_chat_history()
            st.download_button(
                label="Download Chat History",
                data=chat_text,
                file_name=f"afm_chat_{st.session_state.current_label.replace(' ', '_')}.txt",
                mime="text/plain"
            )
    except Exception as e:
        st.error(f"An error occurred processing the image or during chat setup: {e}")
        # Forget the file record (and its cached label) so the next run
        # treats the upload as new and starts from a clean state.
        if "uploaded_file_state" in st.session_state:
            del st.session_state.uploaded_file_state
elif "uploaded_file_state" in st.session_state:
    # If file uploader is cleared by the user after a file was processed
    keys_to_clear = ["messages", "current_label", "uploaded_file_state", "llm", "graph"]
    for key in keys_to_clear:
        if key in st.session_state:
            del st.session_state[key]
    st.rerun()
else:
    st.info("Please upload an image to start the analysis and chat.")
# --- END OF FILE app_main.py ---