# flud/preprocessor.py
# Uploaded by Halfotter ("Upload 16 files", commit 14ebc37, verified)
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
def preprocess_function(examples, tokenizer, max_length=512):
    """
    Tokenize a batch of texts for the steel material classification model.

    Args:
        examples: Mapping with a "text" key holding the input strings.
        tokenizer: Callable tokenizer instance (e.g. a HuggingFace tokenizer).
        max_length: Maximum sequence length; longer inputs are truncated.

    Returns:
        dict: Tokenizer output (input IDs, attention mask, ...) as tensors.
    """
    # Pad/truncate every sample to exactly `max_length` tokens and request
    # PyTorch tensors so the batch can be fed straight to the model.
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
def postprocess_function(predictions, id2label, top_k=5):
    """
    Convert raw model logits into ranked label predictions.

    Args:
        predictions: Raw logits of shape (batch, num_labels); any sequence
            or array that torch.as_tensor() accepts.
        id2label: Mapping from integer label IDs to label names.
        top_k: Number of top predictions to return per sample; clamped to
            the number of available labels (default 5, matching the
            previous hard-coded behavior).

    Returns:
        list[list[dict]]: Per sample, a list of
        {"label", "label_id", "probability"} dicts in descending
        probability order.
    """
    # as_tensor avoids a copy (and a warning) when input is already a tensor.
    logits = torch.as_tensor(predictions)
    # Convert logits to probabilities.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    # Clamp k so torch.topk never requests more entries than there are
    # classes — the previous hard-coded k=5 raised on models with <5 labels.
    k = min(top_k, probabilities.shape[-1])
    top_probs, top_indices = torch.topk(probabilities, k=k, dim=1)
    results = []
    for i in range(probabilities.shape[0]):
        sample_results = []
        for j in range(k):
            label_id = top_indices[i][j].item()
            sample_results.append({
                "label": id2label[label_id],
                "label_id": label_id,
                "probability": top_probs[i][j].item(),
            })
        results.append(sample_results)
    return results
def validate_input(text):
    """
    Check whether *text* is usable as classifier input.

    A valid input is a string that is non-blank after stripping whitespace
    and at most 1000 characters long (a reasonable ceiling for steel
    material descriptions).

    Args:
        text: Candidate input of any type.

    Returns:
        bool: True if the text passes all checks, False otherwise.
    """
    # Short-circuiting guarantees .strip() is only called on actual strings.
    return (
        isinstance(text, str)
        and len(text.strip()) > 0
        and len(text) <= 1000
    )
def clean_text(text):
    """
    Normalize whitespace in a raw input string.

    Collapses every run of whitespace (spaces, tabs, newlines) into a
    single space and removes leading/trailing whitespace.

    Args:
        text: Raw input text.

    Returns:
        str: Whitespace-normalized text.
    """
    tokens = text.split()          # split() discards all whitespace runs
    normalized = " ".join(tokens)  # re-join with single spaces
    # Korean-specific character normalization rules could be added here.
    return normalized.strip()
# Example usage
if __name__ == "__main__":
    # Load the tokenizer that ships alongside this file.
    tokenizer = AutoTokenizer.from_pretrained(".")

    # Sample Korean steel-process descriptions.
    example_texts = [
        "철광석을 고로에서 환원하여 선철을 제조하는 과정",
        "천연가스를 연료로 사용하여 고로를 가열",
        "석회석을 첨가하여 슬래그를 형성",
    ]

    # Keep only inputs that pass validation, whitespace-normalized.
    cleaned_texts = [
        clean_text(text) for text in example_texts if validate_input(text)
    ]

    # Tokenize the surviving texts as one batch.
    examples = {"text": cleaned_texts}
    tokenized = preprocess_function(examples, tokenizer)

    print("=== Preprocessing Example ===")
    print(f"Input texts: {cleaned_texts}")
    print(f"Tokenized shape: {tokenized['input_ids'].shape}")
    print(f"Attention mask shape: {tokenized['attention_mask'].shape}")