# WizNerd Insp — Streamlit app that classifies inspection scope items
# with a Hugging Face sequence-classification model.
import streamlit as st
import pandas as pd
import torch
import os
import re
import base64

# -----------------------------------------------------------------------------
# FIX: st.set_page_config() must be the FIRST Streamlit command executed.
# Previously, a failed transformers import called st.error() *before*
# set_page_config, which makes Streamlit raise StreamlitAPIException.
# We therefore capture the import error and report it only after page config.
# -----------------------------------------------------------------------------
st.set_page_config(page_title="WizNerd Insp", page_icon="π", layout="centered")

_IMPORT_ERROR = None  # Holds the ImportError message (if any) for deferred display.
try:
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    from huggingface_hub import login
    TRANSFORMERS_AVAILABLE = True
except ImportError as e:
    TRANSFORMERS_AVAILABLE = False
    _IMPORT_ERROR = str(e)

if _IMPORT_ERROR is not None:
    st.error(f"Failed to import transformers: {_IMPORT_ERROR}. Please install it with `pip install transformers`.")

# Custom CSS: apply the 'Tw Cen MT' font to the whole app, including tables.
st.markdown("""
    <style>
    @import url('https://fonts.googleapis.com/css2?family=Tw+Cen+MT&display=swap');
    html, body, [class*="css"] {
        font-family: 'Tw Cen MT', sans-serif !important;
    }
    .stTable table {
        font-family: 'Tw Cen MT', sans-serif !important;
    }
    </style>
    """, unsafe_allow_html=True)
# Load Hugging Face token from the environment (needed to access the model repo).
HF_TOKEN = os.getenv("HF_TOKEN")
# Model name on the Hugging Face Hub.
MODEL_NAME = "amiguel/class_insp_program"
# =============================================================================
# FIXED: Label mapping must match EXACTLY what the model was trained with
# The model was trained with 13 classes (Flare Tip and Flare TIP were merged)
# =============================================================================
# Maps the classifier's integer output ids (0-12) to human-readable class names.
LABEL_TO_CLASS = {
    0: "Campaign",
    1: "Corrosion Monitoring",
    2: "Flare Tip",  # This now covers both "Flare Tip" and "Flare TIP"
    3: "FU Items",
    4: "Intelligent Pigging",
    5: "Lifting",
    6: "Non Structural Tank",
    7: "Piping",
    8: "Pressure Safety Device",
    9: "Pressure Vessel (VIE)",
    10: "Pressure Vessel (VII)",
    11: "Structure",
    12: "Flame Arrestor"
}
NUM_LABELS = len(LABEL_TO_CLASS)  # Should be 13
# Required columns - UPDATED
# Every uploaded spreadsheet must contain all three; they are concatenated
# into the classifier's input text by process_file().
REQUIRED_COLS = ["MaintItem text", "Functional Loc.", "Description"]
# Title
# NOTE(review): "π" appears to be a mis-encoded emoji from the original
# source — confirm the intended glyph.
st.title("π Scope Inspektor π")
# Avatars shown next to chat messages (user vs. assistant).
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
| # Sidebar | |
| with st.sidebar: | |
| st.header("Upload Documents π") | |
| uploaded_file = st.file_uploader( | |
| "Choose an XLSX or CSV file", | |
| type=["xlsx", "csv"], | |
| label_visibility="collapsed" | |
| ) | |
| # Show model info | |
| st.markdown("---") | |
| st.markdown(f"**Model:** `{MODEL_NAME}`") | |
| st.markdown(f"**Classes:** {NUM_LABELS}") | |
# Initialize per-session state on the first run of the script.
# Streamlit re-executes the whole file on every interaction, so each key is
# only seeded when it is not already present in session_state.
_SESSION_DEFAULTS = {
    "messages": [],            # chat history: list of {"role", "content"} dicts
    "file_processed": False,   # True once the current upload has been parsed
    "file_data": None,         # result dict from process_file()
    "last_uploaded_file": None,  # previous upload, for change detection
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value
# File processing function
def process_file(uploaded_file, _cache_key):
    """Read an uploaded XLSX/CSV file, validate columns, and build the
    normalized classification input text.

    Args:
        uploaded_file: Streamlit UploadedFile (file-like object exposing
            ``type`` and ``name``) containing the spreadsheet, or None.
        _cache_key: Reserved for cache invalidation; currently unused.
            NOTE(review): the original comment claimed caching, but there is
            no @st.cache_data decorator on this function — confirm whether
            caching was intended.

    Returns:
        dict with keys ``type`` ("table"), ``content`` (DataFrame with an
        added ``input_text`` column) and ``original_df`` (a copy), or None on
        any validation/read failure (an st.error is displayed in that case).
    """
    if uploaded_file is None:
        return None
    try:
        # FIX: prefer the browser-reported MIME type but fall back to the
        # filename extension — some browsers send a generic MIME type for
        # .xlsx uploads, which previously mis-routed them to read_csv.
        is_excel = (
            uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
            or str(getattr(uploaded_file, "name", "")).lower().endswith(".xlsx")
        )
        if is_excel:
            df = pd.read_excel(uploaded_file)
        else:
            df = pd.read_csv(uploaded_file)
        # Check if all required columns are present
        missing_cols = [col for col in REQUIRED_COLS if col not in df.columns]
        if missing_cols:
            st.error(f"Missing required columns: {', '.join(missing_cols)}. Please upload a file with 'MaintItem text', 'Functional Loc.', and 'Description'.")
            return None
        # Drop only rows where ALL required fields are empty; partial rows are kept.
        df = df.dropna(subset=REQUIRED_COLS, how='all')
        # Concatenate the required fields into one lowercase,
        # whitespace-normalized string per row for the classifier.
        df["input_text"] = df[REQUIRED_COLS].apply(
            lambda row: " ".join([re.sub(r'\s+', ' ', str(val).lower().strip()) for val in row if pd.notna(val)]), axis=1
        )
        return {"type": "table", "content": df, "original_df": df.copy()}
    except Exception as e:
        st.error(f"π Error processing file: {str(e)}")
        return None
# Model loading function - FIXED
def load_model(hf_token):
    """Authenticate with Hugging Face and load the classifier + tokenizer.

    Args:
        hf_token: Hugging Face access token. If falsy, an error is shown
            and None is returned.

    Returns:
        (model, tokenizer) tuple on success, or None on any failure
        (transformers unavailable, missing token, or download/load error).
    """
    if not TRANSFORMERS_AVAILABLE:
        return None
    try:
        if not hf_token:
            st.error("π Please set the HF_TOKEN environment variable.")
            return None
        login(token=hf_token)
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
        # =================================================================
        # FIXED: Load model WITHOUT specifying num_labels
        # Let it auto-detect from config.json, or use ignore_mismatched_sizes
        # =================================================================
        try:
            # First try: Load without specifying num_labels (uses config.json)
            model = AutoModelForSequenceClassification.from_pretrained(
                MODEL_NAME,
                token=hf_token
            )
        except Exception as e1:
            # Fallback: Try with explicit num_labels and ignore size mismatch
            st.warning(f"Auto-load failed, trying with explicit config: {str(e1)}")
            model = AutoModelForSequenceClassification.from_pretrained(
                MODEL_NAME,
                num_labels=NUM_LABELS,
                token=hf_token,
                ignore_mismatched_sizes=True  # This allows loading even if sizes differ
            )
        # Use the GPU when available; inference-only, so switch to eval mode.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()
        # Log successful load
        st.sidebar.success(f"β Model loaded on {device}")
        return model, tokenizer
    except Exception as e:
        st.error(f"π€ Model loading failed: {str(e)}")
        import traceback
        st.error(f"Full traceback:\n```\n{traceback.format_exc()}\n```")
        return None
# Classification function - IMPROVED with confidence scores
def classify_instruction(prompt, context, model, tokenizer, return_confidence=False):
    """Classify either a single text or every row of a DataFrame.

    Args:
        prompt: Fallback text used when `context` is empty/falsy.
        context: Either a DataFrame with an ``input_text`` column (batch
            mode) or a plain string (single-text mode).
        model: Loaded sequence-classification model (already on a device).
        tokenizer: Matching tokenizer.
        return_confidence: When True, also return softmax confidence(s).

    Returns:
        Batch mode: list of label names (plus list of floats if requested).
        Single mode: one label name (plus one float if requested).
    """
    model.eval()
    device = model.device

    def _infer(batch):
        """Tokenize `batch`, run one forward pass, return (ids, confidences)."""
        encoded = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            logits = model(**encoded).logits
        probabilities = torch.softmax(logits, dim=-1)
        return (
            logits.argmax(dim=-1).cpu().numpy(),
            probabilities.max(dim=-1).values.cpu().numpy(),
        )

    if isinstance(context, pd.DataFrame):
        predictions = []
        confidences = []
        batch_size = 32  # rows per forward pass
        texts = context["input_text"].tolist()
        for start in range(0, len(texts), batch_size):
            ids, confs = _infer(texts[start:start + batch_size])
            for label_id, conf in zip(ids, confs):
                label_id = int(label_id)
                # Fall back to a placeholder if the model emits an id we don't map.
                predictions.append(LABEL_TO_CLASS.get(label_id, f"Unknown ({label_id})"))
                confidences.append(float(conf))
        if return_confidence:
            return predictions, confidences
        return predictions

    # Single-text path: prefer the provided context, otherwise the prompt.
    text = str(context) if context else prompt
    ids, confs = _infer(text)
    label_id = int(ids[0])
    pred_label = LABEL_TO_CLASS.get(label_id, f"Unknown ({label_id})")
    if return_confidence:
        return pred_label, float(confs[0])
    return pred_label
# Excel download function - inserts Item Class before MaintItem text
def get_excel_download_link(df, filename="predicted_classes.xlsx"):
    """Build an HTML link that downloads `df` as a base64-embedded XLSX file.

    Columns are reordered so "Item Class" and "Confidence" sit immediately
    before "MaintItem text"; the internal "input_text" column is stripped.

    Args:
        df: DataFrame with the prediction columns already attached.
        filename: Name the browser saves the file as.

    Returns:
        An <a> tag (str) with the workbook embedded as a base64 data URI.
    """
    from io import BytesIO
    # Create output dataframe with Item Class before MaintItem text
    output_df = df.copy()
    # Pull the prediction columns out so they can be re-inserted in place.
    cols = list(output_df.columns)
    if "Item Class" in cols:
        cols.remove("Item Class")
    if "Confidence" in cols:
        cols.remove("Confidence")
    # Find MaintItem text position
    if "MaintItem text" in cols:
        maint_idx = cols.index("MaintItem text")
        # Two inserts at the same index yield: Item Class, Confidence, MaintItem text
        cols.insert(maint_idx, "Confidence")
        cols.insert(maint_idx, "Item Class")
    else:
        # Fallback: put at beginning
        cols.insert(0, "Confidence")
        cols.insert(0, "Item Class")
    # Remove input_text column if present (internal use only)
    if "input_text" in cols:
        cols.remove("input_text")
    # Reorder columns (guard against names not actually present in df)
    output_df = output_df[[c for c in cols if c in output_df.columns]]
    # Save to an in-memory Excel workbook
    buffer = BytesIO()
    output_df.to_excel(buffer, index=False, engine='openpyxl')
    buffer.seek(0)
    b64 = base64.b64encode(buffer.getvalue()).decode()
    # FIX: use the `filename` argument for the download attribute; it was
    # previously hard-coded, so the parameter had no effect.
    href = f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">π₯ Download Excel</a>'
    return href
# CSV download function
def get_csv_download_link(df, filename="predicted_classes.csv"):
    """Build an HTML link that downloads `df` as a base64-embedded CSV file.

    Mirrors get_excel_download_link: drops the internal "input_text" column
    and moves "Item Class"/"Confidence" directly before "MaintItem text".

    Args:
        df: DataFrame with the prediction columns already attached.
        filename: Name the browser saves the file as.

    Returns:
        An <a> tag (str) with the CSV embedded as a base64 data URI.
    """
    # Remove internal-only input_text column (ignore if absent)
    output_df = df.drop(columns=["input_text"], errors="ignore")
    # Reorder to put Item Class (then Confidence) before MaintItem text
    cols = list(output_df.columns)
    if "Item Class" in cols and "MaintItem text" in cols:
        cols.remove("Item Class")
        if "Confidence" in cols:
            cols.remove("Confidence")
        maint_idx = cols.index("MaintItem text")
        cols.insert(maint_idx, "Confidence")
        cols.insert(maint_idx, "Item Class")
        output_df = output_df[[c for c in cols if c in output_df.columns]]
    csv = output_df.to_csv(index=False)
    b64 = base64.b64encode(csv.encode()).decode()
    # FIX: use the `filename` argument for the download attribute; it was
    # previously hard-coded, so the parameter had no effect.
    href = f'<a href="data:file/csv;base64,{b64}" download="{filename}">π₯ Download CSV</a>'
    return href
# Load model once per session and cache it in session_state (Streamlit
# re-runs the whole script on each interaction).
if "model" not in st.session_state:
    model_data = load_model(HF_TOKEN)
    if model_data is None and TRANSFORMERS_AVAILABLE:
        # transformers is installed but loading failed -> unrecoverable here.
        st.error("Failed to load model. Check HF_TOKEN.")
        st.stop()
    elif TRANSFORMERS_AVAILABLE:
        st.session_state.model, st.session_state.tokenizer = model_data
# May both be None when transformers is unavailable; checked before use below.
model = st.session_state.get("model")
tokenizer = st.session_state.get("tokenizer")
# Check for new file upload and clear cache
# A different UploadedFile object than last run means the user picked a new
# file, so reset the processed flag and cached results.
if uploaded_file and uploaded_file != st.session_state.last_uploaded_file:
    st.cache_data.clear()
    st.session_state.file_processed = False
    st.session_state.file_data = None
    st.session_state.last_uploaded_file = uploaded_file
# Process uploaded file once (skipped on reruns until a new file arrives)
if uploaded_file and not st.session_state.file_processed:
    # Name + size acts as a cheap identity key for the upload.
    cache_key = f"{uploaded_file.name}_{uploaded_file.size}"
    file_data = process_file(uploaded_file, cache_key)
    if file_data:
        st.session_state.file_data = file_data
        st.session_state.file_processed = True
        st.write(f"β File uploaded with {len(file_data['content'])} rows. Please provide an instruction to classify.")
# Replay the stored chat history so the conversation persists across reruns.
for msg in st.session_state.messages:
    role = msg["role"]
    with st.chat_message(role, avatar=USER_AVATAR if role == "user" else BOT_AVATAR):
        st.markdown(msg["content"])
# Chat input handling: classify the uploaded table (or the prompt itself)
# whenever the user submits a message.
if prompt := st.chat_input("Ask your inspection question..."):
    if not TRANSFORMERS_AVAILABLE:
        st.error("Transformers library not available.")
        st.stop()
    # Echo the user message and persist it in the chat history.
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Handle response
    if model and tokenizer:
        try:
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                if st.session_state.file_data:
                    file_data = st.session_state.file_data
                    if file_data["type"] == "table":
                        # Batch-classify every row of the uploaded table.
                        with st.spinner("Classifying..."):
                            predictions, confidences = classify_instruction(
                                prompt, file_data["content"], model, tokenizer, return_confidence=True
                            )
                        # Attach predictions to a copy of the table.
                        result_df = file_data["content"].copy()
                        result_df["Item Class"] = predictions
                        result_df["Confidence"] = [f"{c:.2%}" for c in confidences]
                        # Display preview (first 10 rows)
                        st.write("**Predicted Item Classes (preview):**")
                        display_cols = ["Item Class", "Confidence"] + REQUIRED_COLS
                        st.dataframe(result_df[display_cols].head(10), use_container_width=True)
                        # Summary stats
                        st.write(f"**Total rows classified:** {len(predictions)}")
                        st.write("**Class distribution:**")
                        st.write(result_df["Item Class"].value_counts())
                        # FIX: guard against ZeroDivisionError when the table
                        # produced no classifiable rows (e.g. all rows empty).
                        avg_conf = sum(confidences) / len(confidences) if confidences else 0.0
                        st.write(f"**Average confidence:** {avg_conf:.2%}")
                        # Download links
                        st.markdown("---")
                        col1, col2 = st.columns(2)
                        with col1:
                            st.markdown(get_excel_download_link(result_df), unsafe_allow_html=True)
                        with col2:
                            st.markdown(get_csv_download_link(result_df), unsafe_allow_html=True)
                        response = f"β Classification completed for {len(predictions)} rows."
                    else:
                        # Non-table file data: classify its content as one text.
                        predicted_class, confidence = classify_instruction(
                            prompt, file_data["content"], model, tokenizer, return_confidence=True
                        )
                        response = f"The Item Class is: **{predicted_class}** (confidence: {confidence:.2%})"
                else:
                    # No file uploaded: classify the prompt text directly.
                    predicted_class, confidence = classify_instruction(
                        prompt, "", model, tokenizer, return_confidence=True
                    )
                    response = f"The Item Class is: **{predicted_class}** (confidence: {confidence:.2%})"
                st.markdown(response)
                st.session_state.messages.append({"role": "assistant", "content": response})
        except Exception as e:
            st.error(f"β‘ Classification error: {str(e)}")
            import traceback
            st.error(f"```\n{traceback.format_exc()}\n```")
    else:
        st.error("π€ Model not loaded!")