# HeartInsight / nlp.py
# Last updated by AmiKim — commit b0d3eaa (verified): "Update nlp.py"
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import re
import os
import json
from datetime import datetime
from io import StringIO
def clean_kakao_message(text):
    """Strip KakaoTalk export artifacts from a single message.

    Removes attachment placeholders ([Photo], [Emoticon], [Video], [File],
    [Shop], [Map]) and URLs, then trims surrounding whitespace.

    Returns the cleaned string, or "" for any non-string input (e.g. NaN
    cells coming from pandas).
    """
    if not isinstance(text, str):
        return ""
    # Apply each removal pattern in turn; order does not affect the result.
    removal_patterns = (
        r'\[Photo\]|\[Emoticon\]|\[Video\]|\[File\]',  # attachment markers
        r'https?://\S+|www\.\S+',                      # URLs
        r'\[Shop\]|\[Map\]',                           # other non-text markers
    )
    for pattern in removal_patterns:
        text = re.sub(pattern, '', text)
    return text.strip()
def analyze_sentiment(text, model, tokenizer):
    """Classify the sentiment of one message with a HF sequence classifier.

    Args:
        text: message string to classify.
        model: a transformers sequence-classification model whose
            ``config.id2label`` maps class indices to label names.
        tokenizer: the matching tokenizer.

    Returns:
        dict with keys ``sentiment`` (label), ``confidence`` (float in
        [0, 1]) and ``text`` (the input), or None for empty / <2-char input.
    """
    # Nothing meaningful to classify.
    if not text or len(text.strip()) < 2:
        return None
    encoded = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512, padding=True
    )
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**encoded).logits
    probs = torch.nn.functional.softmax(logits, dim=1)
    label_idx = torch.argmax(probs, dim=1).item()
    return {
        "sentiment": model.config.id2label[label_idx],
        "confidence": probs[0][label_idx].item(),
        "text": text,
    }
async def analyze_kakao_csv(file, model, tokenizer):
    """Read a KakaoTalk CSV export (raw bytes) and sentiment-analyze each message.

    Args:
        file: raw CSV content as bytes (already read from the upload).
        model: transformers sequence-classification model (passed through to
            ``analyze_sentiment``).
        tokenizer: the matching tokenizer.

    Returns:
        pandas.DataFrame with one row per analyzed message (columns:
        sentiment, confidence, text, timestamp), or None when the CSV
        cannot be decoded/parsed or no message column can be identified.
    """
    # Decode bytes: try UTF-8 first, then CP949 (common for Korean exports).
    # The nested try ensures a failing CP949 fallback is also reported
    # instead of escaping uncaught.
    try:
        try:
            decoded = file.decode("utf-8")
        except UnicodeDecodeError:
            decoded = file.decode("cp949")
        df = pd.read_csv(StringIO(decoded))
    except Exception as e:
        print(f"Error reading CSV: {e}")
        return None

    print("CSV file structure:", df.columns.tolist())

    # Identify the message and timestamp columns by common header names.
    message_col = None
    timestamp_col = None
    possible_cols = ['Text', 'Message', 'Content', 'text', 'message', 'content']
    possible_time_cols = ['Date', 'Time', 'Timestamp', 'date', 'time', 'timestamp']
    for col in possible_cols:
        if col in df.columns:
            message_col = col
            break
    for col in possible_time_cols:
        if col in df.columns:
            timestamp_col = col
            break

    if not message_col:
        # Fall back to guessing: the first object column whose values are,
        # on average, longer than 10 characters. dropna()/astype(str) keeps
        # the heuristic working on mixed-type or partially-empty columns.
        for col in df.columns:
            if df[col].dtype == 'object':
                lengths = df[col].dropna().astype(str).str.len()
                if len(lengths) > 0 and lengths.mean() > 10:
                    message_col = col
                    break

    if not message_col:
        print("Could not find a column containing message content.")
        return None

    print(f"Using '{message_col}' as the message column.")
    if timestamp_col:
        print(f"Using '{timestamp_col}' as the timestamp column.")

    # Normalize messages before classification.
    df['cleaned_message'] = df[message_col].apply(clean_kakao_message)

    results = []
    print(f"Analyzing {len(df)} messages...")
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        message = row['cleaned_message']
        # Skip empty or very short texts (mirrors the guard in analyze_sentiment).
        if not message or len(message.strip()) < 2:
            continue
        # Fall back to "now" when the export has no timestamp column.
        timestamp = row[timestamp_col] if timestamp_col else datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        sentiment_result = analyze_sentiment(message, model, tokenizer)
        if sentiment_result:
            sentiment_result['timestamp'] = timestamp
            results.append(sentiment_result)

    return pd.DataFrame(results)
def get_json_result(results_df, model_name="KCElectra"):
    """Summarize per-message sentiment results as a JSON-serializable dict.

    Args:
        results_df: DataFrame with columns ``sentiment``, ``confidence``,
            ``text`` and ``timestamp`` (as produced by analyze_kakao_csv).
        model_name: label recorded in the output for provenance.

    Returns:
        dict with the sentiment distribution, per-class mean confidence and
        the individual messages, or None when there is nothing to report.
    """
    if results_df is None or len(results_df) == 0:
        print("No results to analyze.")
        return None  # explicit: caller checks for None

    # Timestamp used to identify this analysis run.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Per-class message counts.
    sentiment_counts = results_df['sentiment'].value_counts()

    json_output = {
        "model_name": model_name,
        "analysis_timestamp": timestamp,
        "total_messages": len(results_df),
        "sentiment_distribution": sentiment_counts.to_dict(),
        "average_confidence": results_df.groupby('sentiment')['confidence'].mean().to_dict(),
        "messages": [
            {
                "text": row['text'],
                "sentiment": row['sentiment'],
                # Coerce to plain JSON types: pandas may hold numpy floats /
                # Timestamp objects that json.dumps cannot serialize.
                "confidence": float(row['confidence']),
                "timestamp": str(row['timestamp']),
            }
            for _, row in results_df.iterrows()
        ]
    }
    return json_output