import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np


def preprocess_function(examples, tokenizer, max_length=512):
    """
    Preprocess text data for the steel material classification model.

    Args:
        examples: Dataset examples containing a "text" field (list of strings).
        tokenizer: Tokenizer instance.
        max_length: Maximum sequence length; inputs are padded/truncated to it.

    Returns:
        dict: Tokenized inputs as PyTorch tensors ("input_ids",
        "attention_mask", ...).
    """
    # Pad every sample to max_length so batched tensors have a uniform shape.
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )


def postprocess_function(predictions, id2label, top_k=5):
    """
    Postprocess raw model logits into ranked label predictions.

    Args:
        predictions: Raw model logits, shape (batch, num_labels)
            (array-like; converted to a tensor internally).
        id2label: Mapping from integer label IDs to label names.
        top_k: Number of top predictions to return per sample (default 5,
            matching the original behavior). Clamped to the number of
            classes so small label sets don't make topk fail.

    Returns:
        list[list[dict]]: For each sample, a list of
        {"label", "label_id", "probability"} dicts sorted by probability.
    """
    # Convert logits to probabilities.
    probabilities = torch.nn.functional.softmax(torch.tensor(predictions), dim=-1)

    # Clamp k: torch.topk raises if k exceeds the class dimension.
    k = min(top_k, probabilities.shape[-1])
    top_probs, top_indices = torch.topk(probabilities, k=k, dim=-1)

    results = []
    for sample_probs, sample_indices in zip(top_probs, top_indices):
        sample_results = []
        for prob, idx in zip(sample_probs, sample_indices):
            label_id = idx.item()
            sample_results.append({
                "label": id2label[label_id],
                "label_id": label_id,
                "probability": prob.item(),
            })
        results.append(sample_results)
    return results


def validate_input(text):
    """
    Validate input text for classification.

    Args:
        text: Input text to validate.

    Returns:
        bool: True if the input is a non-empty string of at most
        1000 characters, False otherwise.
    """
    if not isinstance(text, str):
        return False
    if len(text.strip()) == 0:
        return False
    if len(text) > 1000:  # Reasonable limit for steel material descriptions
        return False
    return True


def clean_text(text):
    """
    Clean and normalize input text.

    Args:
        text: Raw input text.

    Returns:
        str: Text with runs of whitespace collapsed to single spaces
        and leading/trailing whitespace removed.
    """
    # Collapse any run of whitespace (tabs, newlines, ...) to one space.
    text = " ".join(text.split())

    # Normalize Korean characters (if needed)
    # Add any specific text cleaning rules here

    return text.strip()


# Example usage
if __name__ == "__main__":
    # Load tokenizer from the current directory (expects a saved tokenizer).
    tokenizer = AutoTokenizer.from_pretrained(".")

    # Example preprocessing
    example_texts = [
        "철광석을 고로에서 환원하여 선철을 제조하는 과정",
        "천연가스를 연료로 사용하여 고로를 가열",
        "석회석을 첨가하여 슬래그를 형성",
    ]

    # Clean and validate texts, keeping only the valid ones.
    cleaned_texts = []
    for text in example_texts:
        if validate_input(text):
            cleaned_text = clean_text(text)
            cleaned_texts.append(cleaned_text)

    # Preprocess
    examples = {"text": cleaned_texts}
    tokenized = preprocess_function(examples, tokenizer)

    print("=== Preprocessing Example ===")
    print(f"Input texts: {cleaned_texts}")
    print(f"Tokenized shape: {tokenized['input_ids'].shape}")
    print(f"Attention mask shape: {tokenized['attention_mask'].shape}")