import torch
import requests
from bs4 import BeautifulSoup
from urllib.parse import quote
def get_synonyms_from_wordsisters(word: str) -> list[str]:
    """Fetch synonyms for *word* from the wordsisters AI endpoint.

    Best-effort: returns an empty list on any network or parsing failure.

    Args:
        word: The word to look up (may contain non-ASCII characters).

    Returns:
        The ``result.synonyms`` list from the JSON response, or ``[]``.
    """
    encoded_word = quote(word)
    # Bug fix: the URL previously interpolated the raw `word` even though
    # `encoded_word` was computed; use the percent-encoded form so
    # non-ASCII / reserved characters produce a valid URL.
    url = f"https://wordsisters.com/api/ai/{encoded_word}"
    headers = {
        "User-Agent": "Mozilla/5.0",
        "Referer": f"https://wordsisters.com/search/{encoded_word}",
    }
    try:
        # Timeout keeps a slow/unreachable host from hanging the caller.
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        data = response.json()
        # Nested .get() calls tolerate a missing "result" key.
        return data.get("result", {}).get("synonyms", [])
    except Exception as e:
        # Deliberate best-effort boundary: log and fall back to no synonyms.
        print(f"Error fetching synonyms: {e}")
        return []
def extract_synonyms_from_html(html: str) -> list[str]:
    """Collect related-word link texts from a dictionary results page.

    Scans every ``.link_relate`` element, keeping each non-empty text once,
    in document order. Returns ``[]`` if the HTML cannot be parsed.
    """
    try:
        parsed = BeautifulSoup(html, "html.parser")
        collected: list[str] = []
        for node in parsed.select(".link_relate"):
            label = node.get_text(strip=True)
            # Preserve first-seen order while skipping blanks and repeats.
            if label and label not in collected:
                collected.append(label)
        print(f"Extracted synonyms: {collected}")
        return collected
    except Exception as e:
        print(f"Error parsing HTML: {e}")
        return []
def get_synonyms_from_daum(word: str) -> list[str]:
    """Search dic.daum.net for *word* and extract synonyms from the page.

    Best-effort: returns an empty list on any failure, consistent with
    ``get_synonyms_from_wordsisters``.

    Args:
        word: The query word.

    Returns:
        Synonyms parsed out of the result HTML, or ``[]`` on error.
    """
    try:
        headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"}
        params = {"q": word}
        # Timeout keeps a slow/unreachable host from hanging the caller.
        response = requests.get(
            "https://dic.daum.net/search.do",
            params=params,
            headers=headers,
            timeout=10,
        )
        response.raise_for_status()
        return extract_synonyms_from_html(response.text)
    except Exception as e:
        print(f"Error fetching from Daum: {e}")
        # Bug fix: the except path previously fell through and implicitly
        # returned None, violating the declared list[str] return type.
        return []
def max_logit(tensor, symDict, tokenizer, max_found: int = 3):
    """Return the highest-scoring candidates whose decoded text is a synonym.

    Walks token ids in descending-logit order and keeps each decoded token
    that appears in *symDict*, stopping after *max_found* matches.

    Args:
        tensor: The ``(values, indices)`` pair returned by ``torch.sort`` on
            the mask-position logits; ``tensor[1][0]`` holds token ids ranked
            best-first.
        symDict: Candidate synonym strings to match against.
        tokenizer: Object with a ``decode`` method mapping a token id to text.
        max_found: Stop after this many matches (default 3, the original
            hard-coded limit).

    Returns:
        Up to *max_found* matched synonym strings, best-scoring first.
    """
    # O(1) membership test replaces the original O(len(symDict)) inner loop.
    candidates = set(symDict)
    found = []
    # Iterate over the actual ranked ids instead of a hard-coded range(32000),
    # so any model vocabulary size works. (Also drops the unused `stop` flag.)
    for token_id in tensor[1][0]:
        text = str(tokenizer.decode(token_id))
        if text in candidates:
            found.append(text)
            if len(found) >= max_found:
                break
    return found
def recommendWord(user_sentence, MaskWord, tokenizer, model):
    """Suggest replacement words for the masked slot in *user_sentence*.

    Runs the fill-mask model, ranks the vocabulary logits at the mask
    position, and keeps only candidates that are also synonyms of
    *MaskWord* (fetched from wordsisters).
    """
    encoded = tokenizer(user_sentence, return_tensors="pt")
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**encoded).logits
    # Positions in the input where the mask token sits.
    mask_positions = (encoded.input_ids == tokenizer.mask_token_id)[0].nonzero(
        as_tuple=True
    )[0]
    synonyms = get_synonyms_from_wordsisters(MaskWord)
    # Rank the mask-position logits best-first, then filter by synonym list.
    ranked = torch.sort(logits[0, mask_positions], dim=-1, descending=True)
    return max_logit(ranked, synonyms, tokenizer)