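"""Score peer-review sentences by topic with a fine-tuned SciBERT classifier.

For each yearly all_reviews_<year>.csv in DATA_DIR, the script splits every
review into sentences with the GLIMPSE tokenizer, predicts a topic label per
sentence, and writes one row per sentence (id, sentence, topic) to
topic_scored_reviews_<year>.csv in OUTPUT_DIR.

Run with --start-year/--end-year to select a range; with no flags, the
available years are auto-detected from the files in DATA_DIR.
"""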
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pathlib import Path
import nltk
from tqdm import tqdm
import sys

nltk.download("punkt")  # sentence-splitting model used by the GLIMPSE tokenizer

# Repository root: two directories above this script's own directory.
BASE_DIR = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(BASE_DIR))  # make the dependencies package importable

from dependencies.Glimpse_tokenizer import glimpse_tokenizer

# === CONFIGURATION ===

MODEL_DIR = BASE_DIR / "scibert" / "scibert_topic" / "final_model"
DATA_DIR = BASE_DIR / "glimpse" / "data" / "processed"
OUTPUT_DIR = BASE_DIR / "data" / "topic_scored"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# === Load model and tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
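# model.eval() disables dropout, and predict_topic runs under torch.no_grad(),
# so inference is deterministic and no autograd graphs are built.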

# === Tokenize like GLIMPSE ===
# Sentence splitting is handled by the imported glimpse_tokenizer, which
# mirrors the original GLIMPSE code: replace '-----' separators with
# newlines, run nltk.sent_tokenize, and drop empty sentences.
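# For example, assuming glimpse_tokenizer follows the logic described above:
#   glimpse_tokenizer("Good paper.-----Needs more baselines.")
#   -> ["Good paper.", "Needs more baselines."]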


# === Label map: class id -> human-readable topic ===
id2label = {
    0: "Substance",
    0: "Substance",
    1: "Clarity",
    2: "Soundness/Correctness",
    3: "Originality",
    4: "Motivation/Impact",
    5: "Meaningful Comparison",
    6: "Replicability",
    7: "NONE"  # This is used for sentences that do not match any specific topic
}

def predict_topic(sentences):
    """Return one human-readable topic label per input sentence."""
    inputs = tokenizer(sentences, return_tensors="pt", padding=True,
                       truncation=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        pred_ids = torch.argmax(outputs.logits, dim=1).cpu().tolist()
    return [id2label[pred] for pred in pred_ids]
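
# Optional: for reviews with many sentences, a single tokenizer call over all
# of them can exhaust GPU memory. A minimal chunked variant (a sketch; the
# helper name and batch size are illustrative, not part of the original):
def predict_topic_batched(sentences, batch_size=64):
    labels = []
    for i in range(0, len(sentences), batch_size):
        labels.extend(predict_topic(sentences[i:i + batch_size]))
    return labels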


def find_topic(start_year=2017, end_year=2021):
    for year in range(start_year, end_year + 1):
        print(f"Processing {year}...")
        input_path = DATA_DIR / f"all_reviews_{year}.csv"
        output_path = OUTPUT_DIR / f"topic_scored_reviews_{year}.csv"

        if not input_path.exists():
            print(f"Skipping {year}: {input_path} not found")
            continue

        df = pd.read_csv(input_path)

        all_rows = []
        for _, row in tqdm(df.iterrows(), total=len(df)):
            review_id = row["id"]
            text = row["text"]
            if not isinstance(text, str):  # skip missing/NaN review texts
                continue
            sentences = glimpse_tokenizer(text)
            if not sentences:
                continue
            labels = predict_topic(sentences)
            for sentence, topic in zip(sentences, labels):
                all_rows.append({"id": review_id, "sentence": sentence, "topic": topic})

        output_df = pd.DataFrame(all_rows)
        output_df.to_csv(output_path, index=False)
        print(f"Saved topic-scored data to {output_path}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Run topic scoring with SciBERT")
    parser.add_argument("--start-year", type=int, default=None, help="Start year (default: auto-detect)")
    parser.add_argument("--end-year", type=int, default=None, help="End year (default: auto-detect)")
    args = parser.parse_args()

    if args.start_year is not None and args.end_year is not None:
        find_topic(start_year=args.start_year, end_year=args.end_year)
    else:
        import re

        # Auto-detect available years from file names like all_reviews_2019.csv
        matches = (re.fullmatch(r"all_reviews_(\d{4})\.csv", f.name)
                   for f in DATA_DIR.glob("all_reviews_*.csv"))
        available = sorted(int(m.group(1)) for m in matches if m)
        if available:
            find_topic(start_year=min(available), end_year=max(available))
        else:
            print("No data files found in", DATA_DIR)