# test/backend/recommendWord.py
# (provenance: uploaded by uuuy5615 — "Upload 37 files", commit 5ea2b9d)
import torch
import requests
from bs4 import BeautifulSoup
from urllib.parse import quote
def get_synonyms_from_wordsisters(word: str) -> list[str]:
    """Fetch synonyms for *word* from the wordsisters API.

    Returns the list found under ``result.synonyms`` in the JSON
    response, or ``[]`` on any network/parse failure (best-effort:
    errors are printed, never raised to the caller).
    """
    encoded_word = quote(word)
    # Bug fix: the request path must use the percent-encoded word —
    # previously `encoded_word` was computed but the raw `word` was
    # interpolated, producing invalid URLs for spaces / non-ASCII input.
    url = f"https://wordsisters.com/api/ai/{encoded_word}"
    headers = {
        "User-Agent": "Mozilla/5.0",
        "Referer": f"https://wordsisters.com/search/{encoded_word}",
    }
    try:
        # timeout keeps a hung server from blocking the caller forever
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        data = response.json()
        return data.get("result", {}).get("synonyms", [])
    except Exception as e:
        print(f"Error fetching synonyms: {e}")
        return []
def extract_synonyms_from_html(html: str) -> list[str]:
    """Pull related-word labels out of a dictionary result page.

    Collects the stripped text of every element matching the
    ``.link_relate`` CSS selector, de-duplicated in first-seen order.
    Returns ``[]`` if parsing fails for any reason.
    """
    try:
        document = BeautifulSoup(html, "html.parser")
        synonyms: list[str] = []
        for node in document.select(".link_relate"):
            label = node.get_text(strip=True)
            # skip empties and keep only the first occurrence of a label
            if not label or label in synonyms:
                continue
            synonyms.append(label)
        print(f"Extracted synonyms: {synonyms}")
        return synonyms
    except Exception as e:
        print(f"Error parsing HTML: {e}")
        return []
def get_synonyms_from_daum(word: str) -> list[str]:
    """Search *word* on the Daum dictionary and scrape related words.

    Returns the synonyms extracted from the result page, or ``[]`` on
    any failure (best-effort: errors are printed, never raised).
    """
    try:
        headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"}
        params = {"q": word}
        # timeout keeps a hung server from blocking the caller forever
        response = requests.get(
            "https://dic.daum.net/search.do",
            params=params,
            headers=headers,
            timeout=10,
        )
        response.raise_for_status()
        return extract_synonyms_from_html(response.text)
    except Exception as e:
        print(f"Error fetching from Daum: {e}")
        # Bug fix: previously this branch fell through and implicitly
        # returned None, violating the declared list[str] contract;
        # return [] like the other fetchers in this module.
        return []
def max_logit(tensor, symDict, tokenizer, max_found=3):
    """Walk tokens in descending-logit order and collect synonym hits.

    Parameters
    ----------
    tensor : the ``(values, indices)`` pair produced by ``torch.sort``;
        ``tensor[1][0]`` is assumed to hold token ids ranked best-first.
    symDict : list[str] of candidate synonym strings to match against.
    tokenizer : object exposing ``decode(token_id) -> str``.
    max_found : stop after this many matches (default 3, preserving the
        original hard-coded limit — new parameter is backward-compatible).

    Returns
    -------
    list[str] of matched synonyms, best-ranked first, at most *max_found*.
    """
    found = []
    # O(1) membership test instead of re-scanning symDict per token.
    candidates = set(symDict)
    # Iterate over however many tokens were actually ranked instead of a
    # hard-coded range(32000): the old loop raised IndexError for smaller
    # vocabularies and silently ignored tokens beyond 32000 for larger ones.
    for token_id in tensor[1][0]:
        decoded = str(tokenizer.decode(token_id))
        if decoded in candidates:
            found.append(decoded)
            if len(found) >= max_found:
                break
    return found
def recommendWord(user_sentence, MaskWord, tokenizer, model):
    """Suggest replacements for the masked position in *user_sentence*.

    Fetches candidate synonyms for *MaskWord*, ranks the model's
    vocabulary logits at the mask position(s), and returns the
    candidates that appear highest in that ranking.
    """
    # Candidate synonyms for the word being replaced.
    candidates = get_synonyms_from_wordsisters(MaskWord)

    encoded = tokenizer(user_sentence, return_tensors="pt")
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    # Position(s) of the [MASK] token in the first (only) sequence.
    mask_positions = (encoded.input_ids == tokenizer.mask_token_id)[0]
    mask_positions = mask_positions.nonzero(as_tuple=True)[0]

    # Full descending sort over the vocabulary; max_logit consumes the
    # (values, indices) pair and picks the best-ranked candidates.
    ranked = torch.sort(logits[0, mask_positions], dim=-1, descending=True)
    return max_logit(ranked, candidates, tokenizer)