# Source: uploaded by remiai3 ("Upload 4 files", commit d67a888, verified).
import argparse, re, itertools
import numpy as np
from sentence_transformers import SentenceTransformer, util
def candidates(text, min_words=1, max_words=3, min_chars=4):
    """Return crude keyphrase candidates extracted from *text*.

    The text is lower-cased and tokenised into words: runs of ASCII letters,
    hyphens and apostrophes that start with a letter and are at least two
    characters long. Every contiguous n-gram with
    ``min_words <= n <= max_words`` whose space-joined form is at least
    ``min_chars`` characters long becomes a candidate.

    Args:
        text: Input text.
        min_words: Smallest n-gram length (in words) to emit; values below 1
            are treated as 1.
        max_words: Largest n-gram length (in words) to emit.
        min_chars: Minimum length of the joined phrase. Shorter phrases
            (e.g. "nlp", "the") are dropped.

    Returns:
        De-duplicated list of phrases in first-occurrence order, grouped by
        n-gram length (all unigrams first, then bigrams, ...). The order is
        deterministic, so callers that truncate the list get a stable subset.
    """
    words = re.findall(r"[A-Za-z][A-Za-z\-']+", text.lower())
    # A dict keeps insertion order. A set's iteration order depends on
    # per-process string-hash randomisation, which made any truncation of the
    # result (e.g. ``[:200]`` in main) keep a different subset on each run.
    phrases = {}
    for n in range(max(1, min_words), max_words + 1):
        for i in range(len(words) - n + 1):
            phrase = " ".join(words[i:i + n])
            if len(phrase) >= min_chars:
                phrases[phrase] = None
    return list(phrases)
def main():
    """Command-line entry point: print the top-k keyphrases of ``--text``.

    Candidates come from ``candidates()``. The full text and every candidate
    are embedded with a SentenceTransformer model (with
    ``normalize_embeddings=True``), and candidates are ranked by cosine
    similarity to the whole text. Prints one ``phrase<TAB>score`` line per
    result, or ``[]`` when the text yields no candidates.
    """
    parser = argparse.ArgumentParser(
        description="Embedding-based keyphrase extraction."
    )
    parser.add_argument("--text", type=str, default="Deep learning enables representation learning using neural networks, widely used in computer vision and NLP.")
    parser.add_argument("--top_k", type=int, default=8)
    parser.add_argument(
        "--max_candidates", type=int, default=200,
        help="Maximum number of candidate phrases to embed (default: 200).",
    )
    parser.add_argument(
        "--model", type=str, default="sentence-transformers/all-MiniLM-L6-v2",
        help="SentenceTransformer model name or path.",
    )
    args = parser.parse_args()
    # A negative slice bound would silently drop results from the end instead
    # of limiting them, so reject it up front.
    if args.top_k < 0:
        parser.error("--top_k must be non-negative")
    if args.max_candidates < 0:
        parser.error("--max_candidates must be non-negative")

    # Extract candidates before loading the model so that input without any
    # candidates exits without paying the model-loading cost.
    cand = candidates(args.text)[:args.max_candidates]
    if not cand:
        print([])
        return

    model = SentenceTransformer(args.model)
    emb_text = model.encode([args.text], normalize_embeddings=True)
    emb_cand = model.encode(cand, normalize_embeddings=True)
    # One query row vs. len(cand) candidate rows; keep the single row of scores.
    scores = util.cos_sim(emb_text, emb_cand).cpu().numpy()[0]
    # Highest similarity first.
    top_idx = np.argsort(scores)[::-1][:args.top_k]
    for i in top_idx:
        print(f"{cand[i]}\t{float(scores[i]):.3f}")


if __name__ == "__main__":
    main()