# classfinetune / app.py
# Last commit by amiguel: "Update app.py" (104349c, verified)
import streamlit as st
import pandas as pd
import torch
import os
import re
import base64
# Optional model dependencies. A failed import is reported in the UI, but the
# message is only emitted after st.set_page_config(): on Streamlit versions
# that require set_page_config to be the first Streamlit command, calling
# st.error() earlier raises instead of showing the install hint.
_IMPORT_ERROR_MSG = None
try:
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    from huggingface_hub import login
    TRANSFORMERS_AVAILABLE = True
except ImportError as e:
    # Build the message here: `e` is unbound once the except block ends.
    _IMPORT_ERROR_MSG = f"Failed to import transformers: {str(e)}. Please install it with `pip install transformers`."
    TRANSFORMERS_AVAILABLE = False

# Set page configuration (first Streamlit command of the script)
st.set_page_config(page_title="WizNerd Insp", page_icon="πŸš€", layout="centered")

if _IMPORT_ERROR_MSG is not None:
    st.error(_IMPORT_ERROR_MSG)

# Custom CSS
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Tw+Cen+MT&display=swap');
html, body, [class*="css"] {
font-family: 'Tw Cen MT', sans-serif !important;
}
.stTable table {
font-family: 'Tw Cen MT', sans-serif !important;
}
</style>
""", unsafe_allow_html=True)
# Hugging Face access token, taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Hub repository holding the fine-tuned sequence-classification model.
MODEL_NAME = "amiguel/class_insp_program"

# Class id -> Item Class name. Ids are the list positions below, so this order
# must match the label ids the model was trained with. Per the original note,
# training used 13 classes, with "Flare Tip" and "Flare TIP" merged into one.
LABEL_TO_CLASS = dict(enumerate([
    "Campaign",
    "Corrosion Monitoring",
    "Flare Tip",  # covers both "Flare Tip" and "Flare TIP"
    "FU Items",
    "Intelligent Pigging",
    "Lifting",
    "Non Structural Tank",
    "Piping",
    "Pressure Safety Device",
    "Pressure Vessel (VIE)",
    "Pressure Vessel (VII)",
    "Structure",
    "Flame Arrestor",
]))
NUM_LABELS = len(LABEL_TO_CLASS)  # 13 entries above

# Columns an uploaded file must contain; their values form the model input.
REQUIRED_COLS = ["MaintItem text", "Functional Loc.", "Description"]
# Chat-bubble avatars (remote images) for the user and the assistant.
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"

# Page heading.
st.title("πŸš€ Scope Inspektor πŸš€")

# Sidebar: the file uploader followed by a short summary of the model.
st.sidebar.header('Upload Documents πŸ"‚')
uploaded_file = st.sidebar.file_uploader(
    "Choose an XLSX or CSV file",
    type=["xlsx", "csv"],
    label_visibility="collapsed",
)
st.sidebar.markdown("---")
st.sidebar.markdown(f"**Model:** `{MODEL_NAME}`")
st.sidebar.markdown(f"**Classes:** {NUM_LABELS}")
# Seed per-session state on a session's first run; existing values survive reruns.
for _state_key, _state_default in (
    ("messages", []),
    ("file_processed", False),
    ("file_data", None),
    ("last_uploaded_file", None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
del _state_key, _state_default
# File processing function with cache
@st.cache_data
def process_file(uploaded_file, _cache_key):
    """Read an uploaded XLSX/CSV file and add the model-input column.

    Args:
        uploaded_file: The sidebar upload, or ``None``.
        _cache_key: Informational only. Its leading underscore makes
            ``st.cache_data`` skip it when hashing, so the cache is keyed on
            ``uploaded_file`` alone.

    Returns:
        ``None`` when no file is given or it cannot be used (an error is shown
        in the UI). Otherwise a dict with ``"type": "table"``, ``"content"``
        (the DataFrame plus an ``input_text`` column) and ``"original_df"``
        (a copy of that same DataFrame).
    """
    if uploaded_file is None:
        return None
    try:
        # Trust the extension as well as the MIME type: the reported type may
        # vary by browser/OS, while the uploader only accepts .xlsx/.csv files.
        is_xlsx = (
            uploaded_file.type
            == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
            or uploaded_file.name.lower().endswith(".xlsx")
        )
        df = pd.read_excel(uploaded_file) if is_xlsx else pd.read_csv(uploaded_file)

        # Check if all required columns are present
        missing_cols = [col for col in REQUIRED_COLS if col not in df.columns]
        if missing_cols:
            st.error(f"Missing required columns: {', '.join(missing_cols)}. Please upload a file with 'MaintItem text', 'Functional Loc.', and 'Description'.")
            return None

        # Keep rows with data in at least one required column.
        df = df.dropna(subset=REQUIRED_COLS, how='all')
        if df.empty:
            # Nothing to classify; the chat step would otherwise divide by zero
            # when averaging confidences.
            st.error("No rows with data in 'MaintItem text', 'Functional Loc.' or 'Description' were found.")
            return None

        whitespace = re.compile(r'\s+')

        def _row_to_text(row):
            # Lower-case, trim and collapse whitespace in each non-missing
            # cell, then join the cells with single spaces.
            return " ".join(
                whitespace.sub(' ', str(val).lower().strip()) for val in row if pd.notna(val)
            )

        df["input_text"] = df[REQUIRED_COLS].apply(_row_to_text, axis=1)
        return {"type": "table", "content": df, "original_df": df.copy()}
    except Exception as e:
        st.error(f'πŸ"„ Error processing file: {str(e)}')
        return None
# Model loading function
@st.cache_resource
def load_model(hf_token):
    """Load the tokenizer and classifier for MODEL_NAME from the Hugging Face Hub.

    Wrapped in ``st.cache_resource``, so loading happens once per distinct
    ``hf_token`` instead of on every rerun.

    Args:
        hf_token: Hugging Face access token; a falsy value aborts loading.

    Returns:
        ``(model, tokenizer)``, with the model in eval mode on CUDA when
        available and on CPU otherwise. Returns ``None`` (after showing an
        error) when transformers is unavailable, the token is missing, or
        loading fails.
    """
    if not TRANSFORMERS_AVAILABLE:
        return None
    try:
        if not hf_token:
            st.error('πŸ" Please set the HF_TOKEN environment variable.')
            return None
        login(token=hf_token)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
        try:
            # Preferred path: the label count comes from the checkpoint's config.
            model = AutoModelForSequenceClassification.from_pretrained(
                MODEL_NAME,
                token=hf_token,
            )
        except Exception as e1:
            # Fallback: force NUM_LABELS. With ignore_mismatched_sizes=True, a
            # classifier head whose shape differs from the checkpoint is newly
            # initialised instead of raising, so its predictions are meaningless.
            st.warning(f"Auto-load failed, trying with explicit config: {str(e1)}")
            model = AutoModelForSequenceClassification.from_pretrained(
                MODEL_NAME,
                num_labels=NUM_LABELS,
                token=hf_token,
                ignore_mismatched_sizes=True,
            )

        # Predictions are translated through LABEL_TO_CLASS. A head of a
        # different size yields ids with no name ("Unknown (n)") or leaves some
        # classes unreachable, so surface the mismatch instead of hiding it.
        model_labels = getattr(model.config, "num_labels", None)
        if model_labels is not None and model_labels != NUM_LABELS:
            st.warning(
                f"Model predicts {model_labels} classes but LABEL_TO_CLASS defines "
                f"{NUM_LABELS}; predicted labels may be reported as 'Unknown' or mislabelled."
            )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()
        st.sidebar.success(f"βœ… Model loaded on {device}")
        return model, tokenizer
    except Exception as e:
        st.error(f"πŸ€– Model loading failed: {str(e)}")
        import traceback
        st.error(f"Full traceback:\n```\n{traceback.format_exc()}\n```")
        return None
# Classification function with confidence scores
def classify_instruction(prompt, context, model, tokenizer, return_confidence=False,
                         *, batch_size=32, max_length=128):
    """Predict Item Class labels with the fine-tuned classifier.

    Args:
        prompt: Free text. It is classified only when *context* is falsy.
        context: Either a DataFrame with an ``input_text`` column (batch mode:
            every row is classified) or a value whose ``str()`` is classified
            when truthy (single mode).
        model: Sequence-classification model; inputs are moved to ``model.device``.
        tokenizer: Tokenizer matching *model*.
        return_confidence: Also return the softmax probability of each prediction.
        batch_size: Rows per forward pass in batch mode (default 32).
        max_length: Truncation length in tokens (default 128).

    Returns:
        Batch mode: a list of labels, plus a list of float confidences when
        requested. Single mode: one label, plus one float confidence when
        requested. Class ids absent from LABEL_TO_CLASS are reported as
        ``"Unknown (<id>)"``.

    Raises:
        ValueError: If *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    model.eval()
    device = model.device

    def _encode(texts):
        # Tokenize one string or a list of strings; move tensors to the model's device.
        encoded = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        return {k: v.to(device) for k, v in encoded.items()}

    def _label(class_id):
        # Ids outside the mapping (model/label-map mismatch) are reported, not dropped.
        return LABEL_TO_CLASS.get(class_id, f"Unknown ({class_id})")

    if isinstance(context, pd.DataFrame):
        predictions = []
        confidences = []
        texts = context["input_text"].tolist()
        for start in range(0, len(texts), batch_size):
            with torch.no_grad():
                logits = model(**_encode(texts[start:start + batch_size])).logits
                probs = torch.softmax(logits, dim=-1)
            # .tolist() yields plain Python ints/floats for lookup and formatting.
            predictions.extend(_label(class_id) for class_id in logits.argmax(dim=-1).tolist())
            confidences.extend(probs.max(dim=-1).values.tolist())
        if return_confidence:
            return predictions, confidences
        return predictions

    # Single text classification
    text = str(context) if context else prompt
    with torch.no_grad():
        logits = model(**_encode(text)).logits
        probs = torch.softmax(logits, dim=-1)
    prediction = logits.argmax().item()
    confidence = probs[0, prediction].item()
    pred_label = _label(prediction)
    if return_confidence:
        return pred_label, confidence
    return pred_label
# Excel download function - inserts Item Class before MaintItem text
def get_excel_download_link(df, filename="predicted_classes.xlsx"):
    """Return an HTML <a> tag that downloads *df* as an .xlsx workbook.

    Layout of the exported sheet:
      * "Item Class" then "Confidence" sit directly before "MaintItem text"
        (or at the very start when that column is absent);
      * the internal "input_text" column is left out;
      * any of those columns missing from *df* are simply skipped.

    The workbook is written with the openpyxl engine and embedded as a
    base64 ``data:`` URI.
    """
    from io import BytesIO

    prediction_cols = ["Item Class", "Confidence"]
    others = [name for name in df.columns if name not in prediction_cols]
    split_at = others.index("MaintItem text") if "MaintItem text" in others else 0
    layout = others[:split_at] + prediction_cols + others[split_at:]
    export = df[[name for name in layout if name != "input_text" and name in df.columns]]

    sink = BytesIO()
    export.to_excel(sink, index=False, engine='openpyxl')
    encoded = base64.b64encode(sink.getvalue()).decode()
    return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{encoded}" download="{filename}">πŸ"₯ Download Excel</a>'
# CSV download function
def get_csv_download_link(df, filename="predicted_classes.csv"):
    """Return an HTML <a> tag that downloads *df* as CSV via a base64 data URI.

    The internal "input_text" column is dropped. When *df* has both
    "Item Class" and "MaintItem text", "Item Class" (followed by
    "Confidence", if present) is moved directly in front of "MaintItem text";
    otherwise the column order is kept.
    """
    table = df.drop(columns=["input_text"], errors="ignore")
    names = list(table.columns)
    if "Item Class" in names and "MaintItem text" in names:
        moved = [name for name in ("Item Class", "Confidence") if name in names]
        rest = [name for name in names if name not in moved]
        pos = rest.index("MaintItem text")
        table = table[rest[:pos] + moved + rest[pos:]]
    encoded = base64.b64encode(table.to_csv(index=False).encode()).decode()
    return f'<a href="data:file/csv;base64,{encoded}" download="{filename}">πŸ"₯ Download CSV</a>'
# Put the model in session state on a session's first run. load_model() is
# itself cached, so a new session reuses the already-loaded weights.
if "model" not in st.session_state:
    model_data = load_model(HF_TOKEN)
    if TRANSFORMERS_AVAILABLE:
        if model_data is None:
            st.error("Failed to load model. Check HF_TOKEN.")
            st.stop()
        else:
            st.session_state.model, st.session_state.tokenizer = model_data
model = st.session_state.get("model")
tokenizer = st.session_state.get("tokenizer")
# Keep the processed upload in sync with the uploader widget.
if uploaded_file is None:
    # The file was removed from the uploader: forget it, so later prompts do
    # not keep classifying a table the user no longer has selected.
    if st.session_state.last_uploaded_file is not None:
        st.session_state.file_processed = False
        st.session_state.file_data = None
        st.session_state.last_uploaded_file = None
elif uploaded_file != st.session_state.last_uploaded_file:
    # A different file arrived: clear only process_file's cache.
    # st.cache_data.clear() would wipe every st.cache_data cache in the app.
    process_file.clear()
    st.session_state.file_processed = False
    st.session_state.file_data = None
    st.session_state.last_uploaded_file = uploaded_file

# Process the current upload once.
if uploaded_file is not None and not st.session_state.file_processed:
    cache_key = f"{uploaded_file.name}_{uploaded_file.size}"
    file_data = process_file(uploaded_file, cache_key)
    if file_data:
        st.session_state.file_data = file_data
        st.session_state.file_processed = True
        st.write(f"βœ… File uploaded with {len(file_data['content'])} rows. Please provide an instruction to classify.")
# Replay the stored conversation on every rerun.
for message in st.session_state.messages:
    is_user = message["role"] == "user"
    with st.chat_message(message["role"], avatar=USER_AVATAR if is_user else BOT_AVATAR):
        st.markdown(message["content"])
# Chat input handling
if prompt := st.chat_input("Ask your inspection question..."):
    if not TRANSFORMERS_AVAILABLE:
        st.error("Transformers library not available.")
        st.stop()
    # Echo the user's message and record it in the history.
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Handle response
    if model is not None and tokenizer is not None:
        try:
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                if st.session_state.file_data:
                    file_data = st.session_state.file_data
                    if file_data["type"] == "table":
                        # Table mode: the prompt only triggers classification
                        # of every row of the uploaded file.
                        with st.spinner("Classifying..."):
                            predictions, confidences = classify_instruction(
                                prompt, file_data["content"], model, tokenizer, return_confidence=True
                            )
                        # Add predictions to a copy of the uploaded table.
                        result_df = file_data["content"].copy()
                        result_df["Item Class"] = predictions
                        result_df["Confidence"] = [f"{c:.2%}" for c in confidences]
                        # Preview of the first 10 rows.
                        st.write("**Predicted Item Classes (preview):**")
                        display_cols = ["Item Class", "Confidence"] + REQUIRED_COLS
                        st.dataframe(result_df[display_cols].head(10), use_container_width=True)
                        # Stats
                        st.write(f"**Total rows classified:** {len(predictions)}")
                        st.write("**Class distribution:**")
                        st.write(result_df["Item Class"].value_counts())
                        # An empty table has no confidences; averaging would divide by zero.
                        if confidences:
                            avg_conf = sum(confidences) / len(confidences)
                            st.write(f"**Average confidence:** {avg_conf:.2%}")
                        # Download links
                        st.markdown("---")
                        col1, col2 = st.columns(2)
                        with col1:
                            st.markdown(get_excel_download_link(result_df), unsafe_allow_html=True)
                        with col2:
                            st.markdown(get_csv_download_link(result_df), unsafe_allow_html=True)
                        response = f"βœ… Classification completed for {len(predictions)} rows."
                    else:
                        predicted_class, confidence = classify_instruction(
                            prompt, file_data["content"], model, tokenizer, return_confidence=True
                        )
                        response = f"The Item Class is: **{predicted_class}** (confidence: {confidence:.2%})"
                else:
                    # No file uploaded: classify the prompt text itself.
                    predicted_class, confidence = classify_instruction(
                        prompt, "", model, tokenizer, return_confidence=True
                    )
                    response = f"The Item Class is: **{predicted_class}** (confidence: {confidence:.2%})"
                st.markdown(response)
                st.session_state.messages.append({"role": "assistant", "content": response})
        except Exception as e:
            st.error(f"βš‘ Classification error: {str(e)}")
            import traceback
            st.error(f"```\n{traceback.format_exc()}\n```")
    else:
        st.error("πŸ€– Model not loaded!")