| | |
| | import gradio as gr |
| | import torch |
| | import random |
| | from sentence_transformers import SentenceTransformer, util |
| | from datasets import load_dataset |
| | from spaces import GPU |
| | import re |
| |
|
| |
|
| |
|
| | |
# Embedding model used for both the query docstring and the candidate code.
# Loaded once at import time and switched to evaluation mode.
model = SentenceTransformer(
    "Shuu12121/CodeSearch-ModernBERT-Crow-v1.1-Plus"
)
model.eval()

# Test split of the code-search dataset that queries and distractors are
# drawn from.
# NOTE(review): trust_remote_code=True allows the dataset repository's loading
# script to run locally -- confirm this is acceptable for the deployment.
dataset = load_dataset(
    "code_x_glue_tc_nl_code_search_adv",
    split="test",
    trust_remote_code=True,
)
| |
|
def remove_comments_from_code(code: str) -> str:
    """Return *code* with every triple-quoted string block deleted.

    Blocks delimited by three double quotes are removed first, then blocks
    delimited by three single quotes. The two passes are kept sequential on
    purpose: the order matters when the two delimiter kinds overlap.
    ``#`` comments are left untouched.

    Args:
        code: Source text to clean.

    Returns:
        The text with all matched triple-quoted spans replaced by nothing.
    """
    for docstring_pattern in (r'"""[\s\S]*?"""', r"'''[\s\S]*?'''"):
        code = re.sub(docstring_pattern, '', code)
    return code
| |
|
| |
|
| | |
def get_query_and_candidates(seed: int = 8520):
    """Pick one query example and build a shuffled list of 100 candidates.

    One example of ``dataset`` is chosen as the query: its ``"docstring"`` is
    the search text and its ``"code"`` (after ``remove_comments_from_code``)
    is the correct answer. 99 other examples are sampled as distractors and
    mixed with the correct answer.

    Args:
        seed: Seed for the random selection; the same seed always produces
            the same query and the same candidate order.

    Returns:
        A ``(doc_str, correct_code, candidates)`` tuple, where ``candidates``
        is a shuffled list of 100 code strings that includes ``correct_code``.

    Raises:
        ValueError: If ``dataset`` holds fewer than 100 examples (raised by
            the random selection calls).
    """
    # A private generator produces the same draws as seeding the module-level
    # one, but does not reset random state shared with other callers.
    rng = random.Random(seed)

    total = len(dataset)
    idx = rng.randint(0, total - 1)
    query = dataset[idx]
    correct_code = remove_comments_from_code(query["code"])
    doc_str = query["docstring"]

    # Sample row indices instead of rows: the pool has the same length and
    # order as before, so the same examples are selected, but only the 99
    # chosen rows are actually read from the dataset.
    negative_pool = [i for i in range(total) if i != idx]
    negative_ids = rng.sample(negative_pool, k=99)

    candidates = [correct_code]
    candidates.extend(
        remove_comments_from_code(dataset[i]["code"]) for i in negative_ids
    )
    rng.shuffle(candidates)

    return doc_str, correct_code, candidates
| |
|
| |
|
| |
|
@GPU  # decorator imported from `spaces`; presumably requests a GPU for the call -- confirm
def code_search_demo(seed: int):
    """Run one docstring-to-code retrieval round and render it as Markdown.

    The query docstring and the candidate snippets returned by
    ``get_query_and_candidates`` are embedded with ``model``, the candidates
    are ranked by cosine similarity to the query, and the ranking is
    formatted as a Markdown report.

    Args:
        seed: Random seed forwarded to ``get_query_and_candidates``.

    Returns:
        A Markdown string containing the query docstring, whether the correct
        snippet is ranked within the top 10, the reciprocal rank cut off at
        10 (0 when the correct snippet is ranked lower), and every candidate
        with its similarity score, best match first.
    """
    doc_str, correct_code, candidates = get_query_and_candidates(seed)

    query_emb = model.encode(doc_str, convert_to_tensor=True)
    candidate_embeddings = model.encode(candidates, convert_to_tensor=True)

    # Convert the score row to plain floats once, so sorting and formatting
    # work on Python numbers instead of comparing tensor elements one by one.
    scores = util.cos_sim(query_emb, candidate_embeddings)[0].tolist()
    results = sorted(
        zip(candidates, scores), key=lambda pair: pair[1], reverse=True
    )

    top_k = 10
    target = correct_code.strip()
    # 1-based position of the first candidate whose text equals the answer.
    rank = next(
        (
            position
            for position, (code, _) in enumerate(results, start=1)
            if code.strip() == target
        ),
        None,
    )
    correct_in_top_k = rank is not None and rank <= top_k
    # The value is labelled MRR@top_k, so a hit below the cutoff scores 0.
    mrr = 1.0 / rank if correct_in_top_k else 0.0

    parts = [
        f"### 🔍 Query Docstring\n\n{doc_str}\n\n",
        f"**✅ 正解は Top-{top_k} に含まれているか?**: {'🟢 Yes' if correct_in_top_k else '🔴 No'}\n\n",
        f"**📈 MRR@{top_k}**: {mrr:.4f}\n\n",
        "## 🏆 Top Matches:\n",
    ]

    medals = ["🥇", "🥈", "🥉"]
    for i, (code, score) in enumerate(results):
        snippet = code.strip()
        label = medals[i] if i < len(medals) else f"#{i+1}"
        is_correct = "✅" if snippet == target else ""
        # Each snippet is truncated to 1000 characters in the report.
        parts.append(
            f"\n**{label}** - Similarity: {score:.4f} {is_correct}\n\n```python\n{snippet[:1000]}\n```\n"
        )

    return "".join(parts)
| |
|
| |
|
| | |
# Gradio front end: a seed slider feeds code_search_demo and the returned
# Markdown report is rendered as the result.
demo = gr.Interface(
    fn=code_search_demo,
    inputs=gr.Slider(0, 100000, value=8520, step=1, label="Random Seed"),
    outputs=gr.Markdown(label="Search Result"),
    # Title and description name the Crow checkpoint that this file loads;
    # they previously advertised a different model name (Owl).
    title="🔎 CodeSearch-ModernBERT-Crow Demo",
    description="docstring から類似 Python 関数を検索(CodeXGlue + ModernBERT-Crow)",
)

# Start the web app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
| |
|