import torch
import requests
from bs4 import BeautifulSoup
from urllib.parse import quote


def get_synonyms_from_wordsisters(word: str) -> list[str]:
    """Fetch synonyms for *word* from the wordsisters API.

    Best-effort: returns an empty list on any network or parse failure.
    """
    encoded_word = quote(word)
    # BUG FIX: the URL previously interpolated the raw `word` even though
    # `encoded_word` had already been computed (only the Referer used it),
    # producing invalid URLs for words with spaces or non-ASCII characters.
    url = f"https://wordsisters.com/api/ai/{encoded_word}"
    headers = {
        "User-Agent": "Mozilla/5.0",
        "Referer": f"https://wordsisters.com/search/{encoded_word}",
    }
    try:
        # Timeout added so a stalled server cannot hang the caller forever.
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        data = response.json()
        return data.get("result", {}).get("synonyms", [])
    except Exception as e:
        print(f"Error fetching synonyms: {e}")
        return []


def extract_synonyms_from_html(html: str) -> list[str]:
    """Extract unique synonym strings from a dictionary result page.

    Collects the text of every element matching the ``.link_relate``
    selector, preserving first-appearance order. Returns [] on any
    parsing failure (best-effort, matching the sibling fetchers).
    """
    try:
        soup = BeautifulSoup(html, "html.parser")
        synonyms: list[str] = []
        seen: set[str] = set()  # O(1) dedup instead of O(n) list scans
        for tag in soup.select(".link_relate"):
            text = tag.get_text(strip=True)
            if text and text not in seen:
                seen.add(text)
                synonyms.append(text)
        print(f"Extracted synonyms: {synonyms}")
        return synonyms
    except Exception as e:
        print(f"Error parsing HTML: {e}")
        return []


def get_synonyms_from_daum(word: str) -> list[str]:
    """Search dic.daum.net for *word* and scrape related-word links.

    Best-effort: returns an empty list on any network failure.
    """
    try:
        headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"}
        response = requests.get(
            "https://dic.daum.net/search.do",
            params={"q": word},
            headers=headers,
            timeout=10,  # avoid hanging indefinitely on a dead connection
        )
        response.raise_for_status()
        return extract_synonyms_from_html(response.text)
    except Exception as e:
        print(f"Error fetching from Daum: {e}")
        # BUG FIX: previously this branch fell through and returned None,
        # breaking the list[str] contract callers iterate over.
        return []


def max_logit(tensor, symDict, tokenizer):
    """Return up to the first 3 ranked tokens that appear in *symDict*.

    `tensor` is the result of ``torch.sort(..., descending=True)``, i.e. a
    (values, indices) pair where ``tensor[1][0]`` is the token-id ranking,
    best first. Each id is decoded via the tokenizer and kept if its text
    matches a synonym; scanning stops after 3 matches.
    """
    # Set membership is O(1) per token instead of scanning symDict for
    # each of the ~32k vocabulary entries.
    candidates = set(symDict)
    found = []
    # Iterate the actual number of ranked token ids rather than the
    # hard-coded 32000, so any vocabulary size works (the old constant
    # raised IndexError for smaller vocabs and missed tokens for larger).
    for token_id in tensor[1][0]:
        decoded = str(tokenizer.decode(token_id))
        if decoded in candidates:
            found.append(decoded)
            if len(found) >= 3:
                break
    return found


def recommendWord(user_sentence, MaskWord, tokenizer, model):
    """Recommend up to 3 synonym replacements for the masked token.

    Runs the masked-LM over *user_sentence*, ranks the model's logits at
    the mask position(s), and returns those candidates that are also
    synonyms of *MaskWord* per the wordsisters API.

    NOTE(review): assumes *user_sentence* contains the tokenizer's mask
    token and that the model is a HuggingFace-style masked LM exposing
    ``.logits`` — confirm against callers.
    """
    inputs = tokenizer(user_sentence, return_tensors="pt")
    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Positions of the mask token in the (single) input sequence.
    mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(
        as_tuple=True
    )[0]
    symDict = get_synonyms_from_wordsisters(MaskWord)
    # Rank token ids at the mask position by logit, best first.
    ts = torch.sort(logits[0, mask_token_index], dim=-1, descending=True)
    # BUG FIX: the original text had the assignment split mid-statement
    # ("found" separated from "= max_logit(...)"), which is a SyntaxError.
    found = max_logit(ts, symDict, tokenizer)
    return found