Spaces:
Sleeping
Sleeping
import sys
import time

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
# 1. Load Model from Hugging Face (Your Team's Checkpoint)
# Downloads the tokenizer and model, retrying up to MAX_RETRIES times and
# terminating the script with exit status 1 if every attempt fails.
MODEL_NAME = "shubharuidas/codebert-base-code-embed-mrl-langchain-langgraph"
MAX_RETRIES = 3
RETRY_DELAY_SECONDS = 5  # pause between failed attempts

print(f"Downloading model: {MODEL_NAME}...")
for attempt in range(1, MAX_RETRIES + 1):
    try:
        print(f"Attempt {attempt}/{MAX_RETRIES}...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModel.from_pretrained(MODEL_NAME)
    except Exception as e:  # broad on purpose: any load failure triggers a retry
        print(f"Attempt {attempt} failed: {e}")
        if attempt == MAX_RETRIES:
            print("Failed to load model after multiple attempts.", file=sys.stderr)
            print("Tip: Check internet connection or repo visibility.", file=sys.stderr)
            # `exit()` is injected by the `site` module and is not guaranteed to
            # exist; SystemExit is always available and keeps the cause chained.
            raise SystemExit(1) from e
        time.sleep(RETRY_DELAY_SECONDS)  # no sleep after the final attempt
    else:
        print("Model loaded successfully!")
        break
# 2. Define Inputs (Query vs Code)
# A natural-language query, a snippet that should match it, and an unrelated
# snippet that should not. The snippets are plain strings and are never executed.
query = "How to create a state graph in langgraph?"

code = (
    "\n"
    "from langgraph.graph import StateGraph\n"
    "def create_workflow():\n"
    "    workflow = StateGraph(AgentState)\n"
    '    workflow.add_node("agent", agent_node)\n'
    "    return workflow.compile()\n"
)

irrelevant_code = "def fast_inverse_sqrt(number): return number ** -0.5"
# 3. Embed & Compare
def mean_pool(hidden_states, attention_mask):
    """Average token vectors per sequence, skipping masked-out (padding) positions.

    hidden_states: (batch, seq_len, dim) tensor of token vectors.
    attention_mask: (batch, seq_len) tensor, 1 for real tokens and 0 for padding.
    Returns a (batch, dim) tensor.
    """
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    summed = (hidden_states * mask).sum(dim=1)
    # clamp avoids division by zero for a row whose mask is all zeros
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


def embed(text):
    """Return L2-normalised, mean-pooled embeddings for *text*.

    Uses the module-level ``tokenizer`` and ``model``. *text* may be a single
    string or a list of strings. Inputs are truncated to 512 tokens, and
    batches are padded to a common length. Padding is excluded from the
    average, so a single string gives the same result as a plain mean over
    its tokens.
    """
    inputs = tokenizer(
        text, return_tensors="pt", max_length=512, truncation=True, padding=True
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # Mean pooling for sentence representation (padding-aware)
    embeddings = mean_pool(outputs.last_hidden_state, inputs["attention_mask"])
    return F.normalize(embeddings, p=2, dim=1)
# Embed the query and both snippets, then check that the relevant snippet
# scores a higher cosine similarity to the query than the irrelevant one.
print("\nRunning Inference Test...")
query_emb, code_emb, irrelevant_emb = (
    embed(piece) for piece in (query, code, irrelevant_code)
)

# 4. Calculate Similarity
sim_positive, sim_negative = (
    F.cosine_similarity(query_emb, candidate).item()
    for candidate in (code_emb, irrelevant_emb)
)

print(
    f"Query: '{query}'",
    f"Similarity to Relevant Code: {sim_positive:.4f} (Should be high)",
    f"Similarity to Irrelevant Code: {sim_negative:.4f} (Should be low)",
    sep="\n",
)

verdict = (
    "\nSUCCESS: Model correctly ranks relevant code higher."
    if sim_positive > sim_negative
    else "\n⚠️ WARNING: Model performance might be poor."
)
print(verdict)