Upload folder using huggingface_hub
Browse files- README.md +2 -3
- app.py +48 -0
- bert_imdb_sentiment.pth +3 -0
- data/imdb_data.py +27 -0
- hub.py +7 -0
- model/sentiment_model.py +23 -0
- predict.py +81 -0
- schemas/sentiment.py +13 -0
- services/inference.py +50 -0
- streamlit_app.py +61 -0
- test.py +16 -0
- train.py +105 -0
README.md
CHANGED
|
@@ -1,3 +1,2 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
---
|
|
|
|
| 1 |
+
# imdb_sentiment_analysis
|
| 2 |
+
基于 BERT 的 IMDB 电影评论情感分析 FastAPI 服务,并已使用 Streamlit 实现了简单的前端。
|
|
|
app.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py — FastAPI service exposing the fine-tuned BERT IMDB sentiment model.
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
import torch
from transformers import BertTokenizer

from model.sentiment_model import SentimentAnalysisModel
from schemas.sentiment import (
    BatchSentimentRequest,
    SentimentRequest,
    SentimentResponse,
)
from services.inference import batch_predict, predict_sentiment

app = FastAPI()


# NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
# favor of lifespan handlers; kept for compatibility with the version this
# project currently runs — confirm before upgrading FastAPI.
@app.on_event("startup")
def startup_event():
    """Load device, tokenizer and model weights once at process start."""
    global tokenizer, model, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = SentimentAnalysisModel("bert-base-uncased")
    model.load_state_dict(torch.load("bert_imdb_sentiment.pth", map_location=device))
    model.to(device)
    model.eval()  # inference mode: disables dropout


@app.post("/predict", response_model=SentimentResponse)
async def predict_api(req: SentimentRequest):
    """Classify a single text.

    The CPU/GPU-bound forward pass runs in a worker thread via
    run_in_threadpool so the event loop is never blocked.
    """
    label, conf = await run_in_threadpool(
        predict_sentiment, req.text, tokenizer, model, device
    )
    return SentimentResponse(label=label, confidence=conf)


@app.post("/predict_batch")
async def predict_batch_api(req: BatchSentimentRequest):
    """Classify a list of texts in one batched forward pass."""
    results = await run_in_threadpool(
        batch_predict, req.texts, tokenizer, model, device
    )
    return results


@app.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}
|
bert_imdb_sentiment.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:399289093f3a22c71110cfcf4f558f8611d025711e74e2de6b39de9236ac04b2
|
| 3 |
+
size 438019196
|
data/imdb_data.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class IMDBDataset(Dataset):
    """IMDB reviews pre-tokenized in a single pass for a BERT-style model.

    The whole split is encoded up front in __init__, so __getitem__ only
    has to wrap the cached token ids in tensors.
    """

    def __init__(self, split, tokenizer, max_length=256):
        print(f"Loading IMDB {split} dataset...")
        self.dataset = load_dataset("imdb")[split]
        print(f"IMDB {split} loaded.")
        # Encode every review once; padding=True pads to the longest text
        # in the split (capped at max_length by truncation).
        self.encodings = tokenizer(
            self.dataset["text"],
            truncation=True,
            padding=True,
            max_length=max_length,
        )
        self.labels = self.dataset["label"]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        ids = self.encodings["input_ids"][idx]
        mask = self.encodings["attention_mask"][idx]
        sample = {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "attention_mask": torch.tensor(mask, dtype=torch.long),
        }
        sample["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample
|
hub.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Push the project folder to the Hugging Face Hub."""
from huggingface_hub import login, upload_folder


def main():
    """Authenticate, then upload the current folder as a model repo."""
    # login() with no arguments prompts for / reuses the cached HF token.
    # (The previous comment claimed a token was embedded here — it is not,
    # and no secret should ever be committed to this file.)
    login()
    upload_folder(folder_path=".", repo_id="ikkbor/bert_imdb_sentiment", repo_type="model")


if __name__ == "__main__":
    # Guarded so importing this module never triggers an upload.
    main()
|
model/sentiment_model.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from transformers import BertModel
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class SentimentAnalysisModel(nn.Module):
    """BERT encoder with a dropout + linear head for binary sentiment."""

    def __init__(self, pretrained_model_name="bert-base-uncased"):
        super().__init__()
        # Pretrained encoder; only its pooled [CLS] representation is used.
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        """Return raw 2-class logits of shape (batch, 2)."""
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        pooled = self.dropout(encoded.pooler_output)
        return self.classifier(pooled)
|
predict.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# predict.py — standalone inference script using the trained checkpoint.
import torch
from transformers import BertTokenizer

from model.sentiment_model import SentimentAnalysisModel


# Device: prefer GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Tokenizer matching the pretrained encoder.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# 2. Model architecture.
model = SentimentAnalysisModel("bert-base-uncased")

# 3. Fine-tuned weights.
model.load_state_dict(
    torch.load("bert_imdb_sentiment.pth", map_location=device)
)

model.to(device)
model.eval()  # critical: disables dropout for deterministic inference

print("Model loaded successfully.")


def predict_sentiment(text):
    """Classify one review; return (label, confidence)."""
    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt"
    )

    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        probs = torch.softmax(outputs, dim=1)
        pred = torch.argmax(probs, dim=1).item()

    label_map = {0: "Negative 😡", 1: "Positive 😊"}
    return label_map[pred], probs[0][pred].item()


def batch_predict(texts):
    """Classify a list of reviews in one forward pass.

    Returns a list of class indices: 0 = Negative, 1 = Positive.
    """
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt"
    )

    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        preds = torch.argmax(outputs, dim=1)

    return preds.cpu().tolist()


if __name__ == "__main__":
    texts = [
        "This movie is terrible.",
        "I really enjoyed this film!",
        "Not bad, but could be better."
    ]

    results = batch_predict(texts)
    print(results)  # e.g. [0, 1, 1]
|
schemas/sentiment.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# schemas/sentiment.py — request/response payloads for the sentiment API.
from pydantic import BaseModel
from typing import List


class SentimentRequest(BaseModel):
    """Single-text classification request."""

    text: str


class SentimentResponse(BaseModel):
    """Predicted label plus softmax confidence for one text."""

    label: str
    confidence: float


class BatchSentimentRequest(BaseModel):
    """Request carrying several texts to classify in one call."""

    texts: List[str]
|
services/inference.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# services/inference.py
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
def predict_sentiment(text, tokenizer, model, device):
    """Classify a single text.

    Returns (label, confidence): label is "Negative" or "Positive",
    confidence is the softmax probability of the predicted class.
    """
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt"
    )
    ids = encoded["input_ids"].to(device)
    mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(ids, mask)
        scores = torch.softmax(logits, dim=1)
        winner = torch.argmax(scores, dim=1).item()

    # Index 0 -> "Negative", index 1 -> "Positive".
    return ("Negative", "Positive")[winner], scores[0][winner].item()
|
| 23 |
+
|
| 24 |
+
def batch_predict(texts, tokenizer, model, device):
    """Classify several texts in a single batched forward pass.

    Returns one {"text", "label", "confidence"} dict per input, in order.
    """
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt"
    )
    ids = encoded["input_ids"].to(device)
    mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(ids, mask)
        probs = torch.softmax(logits, dim=1)
        winners = torch.argmax(probs, dim=1)

    label_map = {0: "Negative", 1: "Positive"}

    results = []
    for row, (text, p) in enumerate(zip(texts, winners)):
        results.append({
            "text": text,
            "label": label_map[p.item()],
            "confidence": probs[row][p].item()
        })
    return results
|
streamlit_app.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# streamlit_app.py — minimal front-end for the FastAPI sentiment service.
# NOTE(review): this file previously contained an unresolved git merge
# conflict (<<<<<<< HEAD ... >>>>>>> 9ef5a78); both sides were identical,
# so the conflict markers and the duplicate copy were removed.
import streamlit as st
import requests

st.title("IMDB 情感分析 Demo")

# Single-text input
text = st.text_area("输入文本", "This movie was amazing!")

if st.button("预测情感"):
    response = requests.post(
        "http://127.0.0.1:8000/predict",
        json={"text": text}
    )
    result = response.json()
    st.write(f"情感:{result['label']}")
    st.write(f"置信度:{result['confidence']:.2f}")

# Batch input, one text per line
batch_texts = st.text_area("批量文本(每行一条)", "I loved it\nNot good")
if st.button("批量预测"):
    texts_list = [line.strip() for line in batch_texts.split("\n") if line.strip()]
    response = requests.post(
        "http://127.0.0.1:8000/predict_batch",
        json={"texts": texts_list}
    )
    results = response.json()
    for r in results:
        st.write(f"{r['text']} → {r['label']} ({r['confidence']:.2f})")
|
test.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# test.py — fire several concurrent requests at /predict to check that the
# service handles parallel load.
import asyncio
import aiohttp


async def main():
    """POST 5 requests concurrently and print each JSON response."""
    async with aiohttp.ClientSession() as session:
        tasks = []
        for i in range(5):  # 5 concurrent requests
            tasks.append(session.post(
                "http://localhost:8000/predict",
                json={"text": f"This is test {i}"}
            ))
        responses = await asyncio.gather(*tasks)
        for r in responses:
            print(await r.json())


if __name__ == "__main__":
    # Guarded so importing this module does not fire network requests
    # (previously asyncio.run(main()) ran at import time).
    asyncio.run(main())
|
train.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from transformers import BertTokenizer
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
from sklearn.metrics import accuracy_score
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
from data.imdb_data import IMDBDataset
|
| 10 |
+
from model.sentiment_model import SentimentAnalysisModel
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def evaluate(model, dataloader, device):
    """Run the model over a dataloader; print and return accuracy."""
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            gold = batch["labels"].to(device)

            logits = model(ids, mask)
            guesses = torch.argmax(logits, dim=1)

            all_preds += guesses.cpu().tolist()
            all_labels += gold.cpu().tolist()

    acc = accuracy_score(all_labels, all_preds)
    print(f"Validation Accuracy: {acc:.4f}")
    return acc
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def train():
    """Fine-tune BERT on IMDB and save the weights to disk."""
    # ================== hyperparameters ==================
    model_name = "bert-base-uncased"
    batch_size = 8
    max_length = 256
    lr = 2e-5
    epochs = 3
    # =====================================================

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    tokenizer = BertTokenizer.from_pretrained(model_name)

    train_dataset = IMDBDataset("train", tokenizer, max_length)
    test_dataset = IMDBDataset("test", tokenizer, max_length)

    # Shared loader settings; only shuffling differs between splits.
    loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    model = SentimentAnalysisModel(model_name).to(device)
    print("Model device:", next(model.parameters()).device)

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # ================== training loop ==================
    for epoch in range(epochs):
        model.train()
        total_loss = 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")
        for step, batch in enumerate(loop):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            logits = model(ids, mask)
            loss = criterion(logits, labels)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            loop.set_postfix(loss=loss.item())

        avg_loss = total_loss / len(train_loader)
        print(f"\nEpoch {epoch + 1} Training Loss: {avg_loss:.4f}")

        # Evaluate on the held-out split after every epoch.
        evaluate(model, test_loader, device)

    # ================== save weights ==================
    torch.save(model.state_dict(), "bert_imdb_sentiment.pth")
    print("Model saved.")


if __name__ == "__main__":
    train()
|