| import torch | |
| import requests | |
| from bs4 import BeautifulSoup | |
| from urllib.parse import quote | |
def get_synonyms_from_wordsisters(word: str) -> list[str]:
    """Fetch synonyms for *word* from the wordsisters API.

    Returns the list of synonym strings from the JSON response, or an
    empty list on any network/HTTP/parse error (best-effort lookup).
    """
    encoded_word = quote(word)
    # BUG FIX: the URL previously interpolated the raw `word`, so the
    # percent-encoded form computed above was never used; words with
    # non-ASCII or reserved characters built a malformed request URL.
    url = f"https://wordsisters.com/api/ai/{encoded_word}"
    headers = {
        "User-Agent": "Mozilla/5.0",
        "Referer": f"https://wordsisters.com/search/{encoded_word}",
    }
    try:
        # timeout keeps a dead endpoint from hanging the caller forever
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        data = response.json()
        # response shape assumed: {"result": {"synonyms": [...]}} — TODO confirm
        return data.get("result", {}).get("synonyms", [])
    except Exception as e:
        # deliberate best-effort: report the failure, return empty
        print(f"Error fetching synonyms: {e}")
        return []
def extract_synonyms_from_html(html: str) -> list[str]:
    """Extract synonym candidates from a Daum dictionary result page.

    Collects the stripped text of every ``.link_relate`` element,
    de-duplicated while preserving first-seen order. Returns an empty
    list if parsing fails for any reason.
    """
    try:
        soup = BeautifulSoup(html, "html.parser")
        synonyms: list[str] = []
        # set gives O(1) membership checks; the original `text not in
        # synonyms` list scan was O(n) per element (O(n^2) overall)
        seen: set[str] = set()
        for tag in soup.select(".link_relate"):
            text = tag.get_text(strip=True)
            if text and text not in seen:
                seen.add(text)
                synonyms.append(text)
        print(f"Extracted synonyms: {synonyms}")
        return synonyms
    except Exception as e:
        # best-effort parse: report and fall back to empty
        print(f"Error parsing HTML: {e}")
        return []
def get_synonyms_from_daum(word: str) -> list[str]:
    """Search the Daum dictionary for *word* and extract synonyms.

    Returns the synonyms scraped from the result page, or an empty list
    on any network/HTTP error.
    """
    try:
        headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"}
        params = {"q": word}
        # timeout keeps a dead endpoint from hanging the caller forever
        response = requests.get(
            "https://dic.daum.net/search.do",
            params=params,
            headers=headers,
            timeout=10,
        )
        response.raise_for_status()
        return extract_synonyms_from_html(response.text)
    except Exception as e:
        print(f"Error fetching from Daum: {e}")
        # BUG FIX: the error path previously fell through and returned
        # None, violating the -> list[str] contract and breaking any
        # caller that iterates the result
        return []
def max_logit(tensor, symDict, tokenizer, top_k: int = 3):
    """Collect synonyms in descending-logit order from sorted token ids.

    Parameters
    ----------
    tensor : the ``(values, indices)`` pair returned by ``torch.sort``;
        ``tensor[1][0]`` is taken as the token ids sorted by logit.
    symDict : list of candidate synonym strings to match against.
    tokenizer : object exposing ``decode(token_id) -> str``.
    top_k : stop after this many matches (default 3, preserving the
        original hard-coded limit).

    Returns
    -------
    list[str] of matched synonyms, best-logit first, at most ``top_k``.
    """
    # set membership is O(1) per token; the original scanned symDict
    # linearly for every one of 32000 token positions
    candidates = set(symDict)
    found = []
    token_ids = tensor[1][0]
    # BUG FIX: iterate the actual number of sorted token ids instead of
    # a hard-coded range(32000), which raised IndexError on smaller
    # vocabularies and silently truncated larger ones
    for token_id in token_ids:
        decoded = str(tokenizer.decode(token_id))
        if decoded in candidates:
            found.append(decoded)
            if len(found) >= top_k:
                break
    return found
def recommendWord(user_sentence, MaskWord, tokenizer, model):
    """Recommend synonyms of ``MaskWord`` that best fill the masked
    position of ``user_sentence`` according to the model's logits.

    Runs a forward pass, locates the mask token, ranks the vocabulary
    at that position by logit, and keeps the top-ranked tokens whose
    decoded text matches a fetched synonym of ``MaskWord``.
    """
    encoding = tokenizer(user_sentence, return_tensors="pt")
    with torch.no_grad():
        logits = model(**encoding).logits
    # position(s) of the [MASK] token in the first (only) sequence
    mask_positions = (encoding.input_ids == tokenizer.mask_token_id)[0]
    mask_token_index = mask_positions.nonzero(as_tuple=True)[0]
    # candidate pool of synonyms for the word being replaced
    synonym_pool = get_synonyms_from_wordsisters(MaskWord)
    # vocabulary ranked by logit at the mask position, best first
    ranked = torch.sort(logits[0, mask_token_index], dim=-1, descending=True)
    return max_logit(ranked, synonym_pool, tokenizer)