Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import torch | |
| from transformers import TapexTokenizer, BartForConditionalGeneration | |
| import xml.etree.ElementTree as ET | |
| from io import StringIO | |
| import logging | |
| from datetime import datetime | |
| import time | |
# Module-wide logging setup: timestamped, INFO-level messages.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)

# Per-module logger, named after this module as is conventional.
logger = logging.getLogger(__name__)
@st.cache_resource
def load_model():
    """
    Load and cache the TAPEX tokenizer and model.

    Bug fix: the original docstring promised "Streamlit's caching" but no
    cache decorator was applied, so the multi-GB model was rebuilt on every
    Streamlit rerun. @st.cache_resource makes the (tokenizer, model) pair a
    true process-wide singleton shared across reruns and sessions.

    Returns:
        tuple: (tokenizer, model) on success, or (None, None) on failure
        (the error is surfaced to the UI via st.error).
    """
    try:
        tokenizer = TapexTokenizer.from_pretrained(
            "microsoft/tapex-large-finetuned-wtq",
            model_max_length=1024
        )
        model = BartForConditionalGeneration.from_pretrained(
            "microsoft/tapex-large-finetuned-wtq"
        )
        # Prefer GPU when available; eval() disables dropout for inference.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = model.to(device)
        model.eval()
        return tokenizer, model
    except Exception as e:
        # Broad catch is deliberate at this UI boundary: any load failure
        # (network, disk, OOM) should degrade to an on-page error message.
        st.error(f"Error loading model: {str(e)}")
        return None, None
def parse_xml_to_dataframe(xml_string: str):
    """
    Parse an XML string of <record> elements into a pandas DataFrame.

    Each child element of a <record> becomes a column. Records missing a
    column get None for it. Columns whose values all parse as numbers are
    converted to numeric dtypes; other columns stay as text.

    Args:
        xml_string: XML document containing <record> elements at any depth.

    Returns:
        tuple: (DataFrame, None) on success, or (None, error_message) on
        failure — callers check the second element.
    """
    try:
        # fromstring parses the string directly; no StringIO wrapper needed.
        root = ET.fromstring(xml_string)
        records = root.findall('.//record')

        # First pass: union of all column names so rows with missing
        # elements still align in the DataFrame.
        columns = set()
        for record in records:
            columns.update(elem.tag for elem in record)

        # Second pass: one dict per record, absent columns default to None.
        data = []
        for record in records:
            row_data = {col: None for col in columns}
            for elem in record:
                row_data[elem.tag] = elem.text
            data.append(row_data)

        df = pd.DataFrame(data)

        # Convert columns that parse cleanly as numbers; leave the rest as
        # text. Narrow except (was a bare `except:`) so only conversion
        # failures are skipped, not SystemExit/KeyboardInterrupt.
        for col in df.columns:
            try:
                df[col] = pd.to_numeric(df[col])
            except (ValueError, TypeError):
                continue

        return df, None
    except Exception as e:
        return None, f"Error parsing XML: {str(e)}"
def process_query(tokenizer, model, df, query: str):
    """
    Answer a natural-language question about *df*.

    Common aggregate questions (max / average / sum over a numeric column
    whose name appears in the query) are answered directly from the
    DataFrame; everything else is delegated to the TAPEX model.

    Args:
        tokenizer: TAPEX tokenizer (unused on the DataFrame shortcut paths).
        model: seq2seq model with .generate() (unused on shortcut paths).
        df: parsed table.
        query: the user's question.

    Returns:
        dict for the "highest"/"maximum" shortcut (the winning row), or a
        formatted answer string; an error string on failure.
    """
    try:
        start_time = time.time()
        query_lower = query.lower()
        numeric_cols = df.select_dtypes(include=['number']).columns

        # Cheap DataFrame shortcuts: only fire when a numeric column's name
        # literally appears in the query; otherwise fall through to TAPEX.
        if "highest" in query_lower or "maximum" in query_lower:
            for col in numeric_cols:
                if col.lower() in query_lower:
                    return df.loc[df[col].idxmax()].to_dict()
        elif "average" in query_lower or "mean" in query_lower:
            for col in numeric_cols:
                if col.lower() in query_lower:
                    return f"Average {col}: {df[col].mean():.2f}"
        elif "total" in query_lower or "sum" in query_lower:
            for col in numeric_cols:
                if col.lower() in query_lower:
                    return f"Total {col}: {df[col].sum():.2f}"

        # Use TAPEX for more complex queries.
        with torch.no_grad():
            encoding = tokenizer(
                table=df.astype(str),
                query=query,
                return_tensors="pt",
                padding=True,
                truncation=True
            )
            # Bug fix: load_model may have moved the model to CUDA, but the
            # tokenizer returns CPU tensors — generate() would fail with a
            # device mismatch. Move the inputs to the model's device.
            encoding = {k: v.to(model.device) for k, v in encoding.items()}
            outputs = model.generate(**encoding)
        answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

        processing_time = time.time() - start_time
        return f"Answer: {answer} (Processing time: {processing_time:.2f}s)"
    except Exception as e:
        # UI boundary: degrade any failure to a printable error string.
        return f"Error processing query: {str(e)}"
def main():
    """
    Streamlit entry point: render the XML query UI.

    Flow: load the TAPEX model (with spinner), accept XML input (with a
    sample loader), parse it into a DataFrame, then answer free-text
    questions about the data. Sample-query buttons write into
    session_state and trigger st.rerun() so the widgets pick up the values
    through their value= bindings.
    """
    st.title("XML Data Query System")
    st.write("Upload your XML data and ask questions about it!")
    # Initialize session state for XML input and query if not exists
    if 'xml_input' not in st.session_state:
        st.session_state.xml_input = ""
    if 'current_query' not in st.session_state:
        st.session_state.current_query = ""
    # Load model (first call downloads/builds the model, hence the spinner)
    with st.spinner("Loading TAPEX model... (this may take a few moments)"):
        tokenizer, model = load_model()
    if tokenizer is None or model is None:
        st.error("Failed to load the model. Please refresh the page.")
        return
    # XML Input — seeded from session state so "Load Sample XML" survives
    # the st.rerun() below.
    xml_input = st.text_area(
        "Enter your XML data here:",
        value=st.session_state.xml_input,
        height=200,
        help="Paste your XML data here. Make sure it's properly formatted."
    )
    # Sample XML button: store the demo payload, then rerun so the text
    # area above re-renders with it.
    if st.button("Load Sample XML"):
        st.session_state.xml_input = """<?xml version="1.0" encoding="UTF-8"?>
<data>
    <records>
        <record>
            <company>Apple</company>
            <revenue>365.7</revenue>
            <employees>147000</employees>
            <year>2021</year>
        </record>
        <record>
            <company>Microsoft</company>
            <revenue>168.1</revenue>
            <employees>181000</employees>
            <year>2021</year>
        </record>
        <record>
            <company>Amazon</company>
            <revenue>386.1</revenue>
            <employees>1608000</employees>
            <year>2021</year>
        </record>
    </records>
</data>"""
        st.rerun()
    if xml_input:
        df, error = parse_xml_to_dataframe(xml_input)
        if error:
            st.error(error)
        else:
            st.success("XML parsed successfully!")
            # Display DataFrame
            st.subheader("Parsed Data:")
            st.dataframe(df)
            # Query input — seeded from session state so the sample-query
            # buttons below can inject a question.
            query = st.text_input(
                "Enter your question about the data:",
                value=st.session_state.current_query,
                help="Example: 'Which company has the highest revenue?'"
            )
            # Process query
            if query:
                with st.spinner("Processing query..."):
                    result = process_query(tokenizer, model, df, query)
                st.write(result)
            # Sample queries
            st.subheader("Sample Questions (Click to use):")
            sample_queries = [
                "Which company has the highest revenue?",
                "What is the average revenue of all companies?",
                "How many employees does Microsoft have?",
                "Which company has the most employees?",
                "What is the total revenue of all companies?"
            ]
            # Create columns for sample query buttons; explicit key= keeps
            # the buttons distinct across reruns.
            cols = st.columns(len(sample_queries))
            for idx, (col, sample_query) in enumerate(zip(cols, sample_queries)):
                with col:
                    if st.button(f"Query {idx + 1}", help=sample_query, key=f"query_btn_{idx}"):
                        st.session_state.current_query = sample_query
                        st.rerun()
            # Display the sample queries as text for reference
            # NOTE(review): this loop rebinds the local `query`; harmless
            # here since `query` is not used afterwards.
            with st.expander("View all sample questions"):
                for idx, query in enumerate(sample_queries, 1):
                    st.write(f"{idx}. {query}")
# Standard script entry guard: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()