File size: 2,588 Bytes
5ea2b9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import requests
from bs4 import BeautifulSoup
from urllib.parse import quote


def get_synonyms_from_wordsisters(word: str) -> list[str]:
    """Fetch synonyms for *word* from the wordsisters API.

    Returns the list found under ``result.synonyms`` in the JSON response,
    or an empty list on any network/parse error (best-effort).
    """
    encoded_word = quote(word)
    # Bug fix: the request URL must use the percent-encoded word; the original
    # interpolated the raw `word`, producing invalid URLs for inputs with
    # spaces or non-ASCII characters (the encoded form was only used in the
    # Referer header).
    url = f"https://wordsisters.com/api/ai/{encoded_word}"
    headers = {
        "User-Agent": "Mozilla/5.0",
        "Referer": f"https://wordsisters.com/search/{encoded_word}",
    }
    try:
        # Timeout keeps a stalled server from blocking the caller forever.
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()

        data = response.json()
        return data.get("result", {}).get("synonyms", [])
    except Exception as e:
        # Best-effort: log and degrade to an empty synonym list.
        print(f"Error fetching synonyms: {e}")
        return []


def extract_synonyms_from_html(html: str) -> list[str]:
    """Collect related-word link texts from a dictionary result page.

    Selects every ``.link_relate`` element, keeping non-empty texts in
    document order with duplicates dropped. Returns [] if parsing fails.
    """
    try:
        parsed = BeautifulSoup(html, "html.parser")
        collected: list[str] = []
        for node in parsed.select(".link_relate"):
            label = node.get_text(strip=True)
            # Preserve first-seen order while de-duplicating.
            if label and label not in collected:
                collected.append(label)

        print(f"Extracted synonyms: {collected}")
        return collected
    except Exception as e:
        print(f"Error parsing HTML: {e}")
        return []


def get_synonyms_from_daum(word: str) -> list[str]:
    """Search dic.daum.net for *word* and scrape synonyms from the page.

    Returns the synonyms extracted from the HTML, or an empty list on any
    network error (best-effort).
    """
    try:
        headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"}
        params = {"q": word}

        # Timeout keeps a stalled request from blocking the caller forever.
        response = requests.get(
            "https://dic.daum.net/search.do",
            params=params,
            headers=headers,
            timeout=10,
        )
        response.raise_for_status()

        return extract_synonyms_from_html(response.text)
    except Exception as e:
        print(f"Error fetching from Daum: {e}")
        # Bug fix: the original fell off the end here and returned None,
        # violating the declared list[str] return type and breaking callers
        # that iterate the result.
        return []


def max_logit(tensor, symDict, tokenizer):
    """Select up to 3 candidates from *symDict* in model-preference order.

    Parameters
    ----------
    tensor : the ``(values, indices)`` pair returned by ``torch.sort``;
        ``tensor[1][0]`` is the sequence of token ids ordered from highest
        to lowest logit.
    symDict : list[str] of candidate synonym strings.
    tokenizer : object exposing ``decode(token_id)``.

    Returns
    -------
    list[str] of at most 3 matching candidates, best-ranked first.
    """
    # O(1) membership test instead of the original O(len(symDict)) inner scan.
    candidates = set(symDict)
    found = []
    # Generalized: iterate the actual ranked-token sequence rather than a
    # hard-coded 32000-entry vocabulary, so any vocab size works. (Also drops
    # the original's unused `stop` flag.)
    for token_id in tensor[1][0]:
        decoded = str(tokenizer.decode(token_id))
        if decoded in candidates:
            found.append(decoded)
            if len(found) >= 3:
                break
    return found


def recommendWord(user_sentence, MaskWord, tokenizer, model):
    """Recommend replacements for the mask token in *user_sentence*.

    Runs the masked-LM once, ranks the logits at the mask position, and
    keeps the top candidates that are also synonyms of *MaskWord*
    (per the wordsisters lookup).
    """
    encoded = tokenizer(user_sentence, return_tensors="pt")
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**encoded).logits
    # Locate the position(s) of the mask token in the first sequence.
    mask_flags = (encoded.input_ids == tokenizer.mask_token_id)[0]
    mask_token_index = mask_flags.nonzero(as_tuple=True)[0]
    synonyms = get_synonyms_from_wordsisters(MaskWord)
    # Rank all vocabulary logits at the mask position, best first.
    ranked = torch.sort(logits[0, mask_token_index], dim=-1, descending=True)
    return max_logit(ranked, synonyms, tokenizer)