| import argparse, re, itertools |
| import numpy as np |
| from sentence_transformers import SentenceTransformer, util |
|
|
def candidates(text):
    """Return the unique 1- to 3-word phrases (longer than 3 chars) found in *text*.

    Tokens are lowercased runs of two or more letters; hyphens and
    apostrophes are allowed after the first letter. Order of the returned
    list is unspecified (it comes from a set).
    """
    tokens = re.findall(r"[A-Za-z][A-Za-z\-']+", text.lower())
    # Every contiguous n-gram for n = 1..3, deduplicated via a set.
    ngrams = {
        " ".join(tokens[start:start + size])
        for size in (1, 2, 3)
        for start in range(len(tokens) - size + 1)
    }
    # Keep only phrases longer than 3 characters.
    return [phrase for phrase in ngrams if len(phrase) > 3]
|
|
def main():
    """CLI entry point: print the top-k candidate phrases most similar to --text.

    Prints one "phrase<TAB>score" line per result, or "[]" when no
    candidates can be extracted from the input text.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--text", type=str, default="Deep learning enables representation learning using neural networks, widely used in computer vision and NLP.")
    parser.add_argument("--top_k", type=int, default=8)
    opts = parser.parse_args()

    encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    # Cap the candidate list to bound the embedding cost.
    phrases = candidates(opts.text)[:200]
    if not phrases:
        print([])
        return

    # With normalized embeddings, cosine similarity reduces to a dot product.
    text_vec = encoder.encode([opts.text], normalize_embeddings=True)
    phrase_vecs = encoder.encode(phrases, normalize_embeddings=True)
    sims = util.cos_sim(text_vec, phrase_vecs).cpu().numpy()[0]

    # Indices of the highest-scoring phrases, best first.
    best = np.argsort(sims)[::-1][:opts.top_k]
    for idx in best:
        score = float(sims[idx])
        print(f"{phrases[idx]}\t{score:.3f}")
|
|
| if __name__ == "__main__": |
| main() |
|
|