| import torch |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, AutoModel |
|
|
| |
import time

# Identifier passed to ``from_pretrained`` for both the tokenizer and model.
MODEL_NAME = "shubharuidas/codebert-base-code-embed-mrl-langchain-langgraph"

# Total number of load attempts, and the pause (seconds) after a failed one.
MAX_RETRIES = 3
RETRY_DELAY_SECONDS = 5

# Load the tokenizer and model, retrying a few times before giving up.
# On success this leaves ``tokenizer`` and ``model`` bound at module level;
# on final failure the script terminates with exit status 1.
print(f"Downloading model: {MODEL_NAME}...")
for attempt in range(1, MAX_RETRIES + 1):
    print(f"Attempt {attempt}/{MAX_RETRIES}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModel.from_pretrained(MODEL_NAME)
    except Exception as e:
        # Deliberately broad: this is the script's top-level boundary and any
        # load failure is reported and retried the same way.
        print(f"Attempt {attempt} failed: {e}")
        if attempt < MAX_RETRIES:
            # Only wait when another attempt will actually follow.
            time.sleep(RETRY_DELAY_SECONDS)
    else:
        print("Model loaded successfully!")
        break
else:
    # Reached only when no attempt hit ``break`` (including MAX_RETRIES == 0).
    print("Failed to load model after multiple attempts.")
    print("Tip: Check internet connection or repo visibility.")
    # ``exit()`` is injected by the ``site`` module for interactive use and is
    # not guaranteed to exist; raising SystemExit is the portable equivalent.
    raise SystemExit(1)
|
|
| |
# Natural-language query used as the retrieval probe.
query = "How to create a state graph in langgraph?"

# Snippet that should rank as relevant to the query. The function body is
# indented so the sample is syntactically valid Python; the names it refers
# to (AgentState, agent_node) are never resolved because the text is only
# embedded, not executed.
code = """
from langgraph.graph import StateGraph

def create_workflow():
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    return workflow.compile()
"""

# Unrelated snippet that should rank lower than ``code`` for the same query.
irrelevant_code = "def fast_inverse_sqrt(number): return number ** -0.5"
|
|
| |
def embed(text):
    """Return L2-normalised, mean-pooled embeddings for ``text``.

    Uses the module-level ``tokenizer`` and ``model``. Inputs longer than
    512 tokens are truncated.

    Args:
        text: A single string, or a list of strings to embed as one batch.

    Returns:
        A 2-D tensor with one unit-length row per input text (a single row
        when ``text`` is one string).
    """
    # padding=True adds nothing for a single string, but lets texts of
    # different lengths be stacked into one batch tensor.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )
    with torch.no_grad():
        outputs = model(**inputs)

    hidden = outputs.last_hidden_state
    # Average over real tokens only: padded positions have mask 0 and must
    # not dilute the mean. With an all-ones mask this equals mean(dim=1).
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    token_counts = mask.sum(dim=1).clamp(min=1)  # guard against divide-by-zero
    embeddings = (hidden * mask).sum(dim=1) / token_counts
    return F.normalize(embeddings, p=2, dim=1)
|
|
# Smoke test: embed the query and both snippets, then check that the relevant
# snippet scores a higher cosine similarity to the query than the unrelated one.
print("\nRunning Inference Test...")
query_emb, code_emb, irrelevant_emb = (
    embed(sample) for sample in (query, code, irrelevant_code)
)

# Each similarity is reduced to a plain Python float for formatting/comparison.
sim_positive = F.cosine_similarity(query_emb, code_emb).item()
sim_negative = F.cosine_similarity(query_emb, irrelevant_emb).item()

# Report the raw scores alongside the expectation for each.
print(f"Query: '{query}'")
print(f"Similarity to Relevant Code: {sim_positive:.4f} (Should be high)")
print(f"Similarity to Irrelevant Code: {sim_negative:.4f} (Should be low)")

# Verdict: a tie counts as a warning, since the comparison is strict.
verdict = (
    "\nSUCCESS: Model correctly ranks relevant code higher."
    if sim_positive > sim_negative
    else "\n⚠️ WARNING: Model performance might be poor."
)
print(verdict)
|
|