import streamlit as st
import javalang
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import numpy as np
import networkx as nx
from transformers import AutoTokenizer, AutoModel
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import warnings
import pandas as pd
import zipfile
import os
from collections import defaultdict
# Set up page config
st.set_page_config(
    page_title="Advanced Java Code Clone Detector (IJaDataset 2.1)",
    page_icon="🔍",
    layout="wide"
)

# Suppress warnings
warnings.filterwarnings("ignore")

# Constants
MODEL_NAME = "microsoft/codebert-base"
MAX_LENGTH = 512
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATASET_PATH = "archive (1).zip"  # Update this path if needed

# Initialize models with caching (st.cache_resource keeps the models in
# memory across Streamlit reruns instead of reloading them on every click)
@st.cache_resource
def load_models():
    try:
        # Load CodeBERT for semantic analysis
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        code_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)

        # Initialize RNN model
        class RNNModel(nn.Module):
            def __init__(self, input_size, hidden_size, num_layers):
                super(RNNModel, self).__init__()
                self.hidden_size = hidden_size
                self.num_layers = num_layers
                self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
                self.fc = nn.Linear(hidden_size, 1)

            def forward(self, x):
                h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(DEVICE)
                out, _ = self.rnn(x, h0)
                out = self.fc(out[:, -1, :])
                return out

        rnn_model = RNNModel(input_size=768, hidden_size=256, num_layers=2).to(DEVICE)

        # Initialize GNN model
        class GNNModel(nn.Module):
            def __init__(self, node_features):
                super(GNNModel, self).__init__()
                self.conv1 = GCNConv(node_features, 128)
                self.conv2 = GCNConv(128, 64)
                self.fc = nn.Linear(64, 1)

            def forward(self, data):
                x, edge_index = data.x, data.edge_index
                x = F.relu(self.conv1(x, edge_index))
                x = F.dropout(x, training=self.training)
                x = self.conv2(x, edge_index)
                x = self.fc(x)
                return torch.sigmoid(x.mean())

        gnn_model = GNNModel(node_features=128).to(DEVICE)
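
        # Note: rnn_model and gnn_model are created with random weights and are
        # never trained anywhere in this script, so their similarity scores
        # reflect untrained networks; only the pre-trained CodeBERT component
        # carries learned knowledge.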

        return tokenizer, code_model, rnn_model, gnn_model
    except Exception as e:
        st.error(f"Failed to load models: {str(e)}")
        return None, None, None, None

@st.cache_data
def load_dataset():
    try:
        # Extract dataset if needed
        if not os.path.exists("Diverse_100K_Dataset"):
            with zipfile.ZipFile(DATASET_PATH, 'r') as zip_ref:
                zip_ref.extractall(".")

        # Load sample pairs (modify this based on your dataset structure)
        clone_pairs = []
        base_path = "Subject_CloneTypes_Directories"

        # Load pairs from all clone types
        for clone_type in ["Clone_Type1", "Clone_Type2", "Clone_Type3 - ST", "Clone_Type4"]:
            type_path = os.path.join(base_path, clone_type)
            if os.path.exists(type_path):
                for root, _, files in os.walk(type_path):
                    # Take the first two files in a directory as a pair
                    if len(files) >= 2:
                        with open(os.path.join(root, files[0]), 'r', encoding='utf-8') as f1:
                            code1 = f1.read()
                        with open(os.path.join(root, files[1]), 'r', encoding='utf-8') as f2:
                            code2 = f2.read()
                        clone_pairs.append({
                            "type": clone_type,
                            "code1": code1,
                            "code2": code2
                        })
                        break  # Just take one pair per type for demo

        return clone_pairs[:10]  # Return first 10 pairs for demo
    except Exception as e:
        st.error(f"Error loading dataset: {str(e)}")
        return []

tokenizer, code_model, rnn_model, gnn_model = load_models()
dataset_pairs = load_dataset()

# AST Processing Functions
def parse_ast(code):
    try:
        tokens = javalang.tokenizer.tokenize(code)
        parser = javalang.parser.Parser(tokens)
        tree = parser.parse()
        return tree
    except Exception as e:
        st.warning(f"AST parsing error: {str(e)}")
        return None

def build_ast_graph(ast_tree):
    if not ast_tree:
        return None

    G = nx.DiGraph()
    node_id = 0
    node_map = {}

    def traverse(node, parent_id=None):
        nonlocal node_id
        current_id = node_id
        node_label = str(type(node).__name__)
        node_map[current_id] = {'type': node_label, 'node': node}
        G.add_node(current_id, type=node_label)
        if parent_id is not None:
            G.add_edge(parent_id, current_id)
        node_id += 1
        for child in node.children:
            if isinstance(child, javalang.ast.Node):
                traverse(child, current_id)
            elif isinstance(child, (list, tuple)):
                for item in child:
                    if isinstance(item, javalang.ast.Node):
                        traverse(item, current_id)

    traverse(ast_tree)
    return G, node_map
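
# The AST visualization expander further down references an ast_to_graphviz
# helper that the original script never defines. Below is a hypothetical
# minimal sketch using the `graphviz` package (an extra dependency, hence the
# lazy import); the 20-node cap mirrors the expander's "First 20 nodes" label.
def ast_to_graphviz(ast_tree, max_nodes=20):
    result = build_ast_graph(ast_tree)
    if result is None:
        return None
    import graphviz
    G, _ = result
    dot = graphviz.Digraph()
    # Node ids are assigned in traversal order, so the first max_nodes ids
    # correspond to the first nodes visited.
    for node in list(G.nodes())[:max_nodes]:
        dot.node(str(node), G.nodes[node]['type'])
    for src, dst in G.edges():
        if src < max_nodes and dst < max_nodes:
            dot.edge(str(src), str(dst))
    return dot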

def ast_to_pyg_data(ast_graph):
    if not ast_graph:
        return None

    # Convert AST to PyTorch Geometric Data format
    node_features = []
    for node in ast_graph.nodes():
        node_type = ast_graph.nodes[node]['type']
        # Hash-based one-hot encoding of node types. The feature width must
        # match GNNModel(node_features=128); the original used 50, which
        # crashes the first GCNConv layer. In practice use a fixed vocabulary
        # instead: Python's str hash also varies between runs unless
        # PYTHONHASHSEED is set.
        feature = [0] * 128
        feature[hash(node_type) % 128] = 1
        node_features.append(feature)

    # Convert networkx edges to edge_index format
    edge_index = list(ast_graph.edges())
    if not edge_index:
        # Add a self-loop if there are no edges
        edge_index = [(0, 0)]

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    x = torch.tensor(node_features, dtype=torch.float)
    return Data(x=x, edge_index=edge_index)
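
# Illustrative shape check (values hypothetical): for a 3-node AST with edges
# 0->1 and 0->2, the call above yields Data(x=[3, 128], edge_index=[2, 2]),
# i.e. one 128-dim feature row per node, with edges stored column-wise.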

# Normalization function
def normalize_code(code):
    try:
        code = re.sub(r'//.*', '', code)  # Remove single-line comments
        code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)  # Multi-line comments
        code = re.sub(r'\s+', ' ', code).strip()  # Normalize whitespace
        return code
    except Exception:
        return code
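
# For example: normalize_code("int x = 1;  // counter") returns "int x = 1;".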

# Feature extraction functions
def get_lexical_features(code):
    """Extract lexical features (for Type-1 and Type-2 clones)"""
    normalized = normalize_code(code)
    tokens = re.findall(r'\b\w+\b', normalized)
    return {
        'token_count': len(tokens),
        'unique_tokens': len(set(tokens)),
        'avg_token_length': np.mean([len(t) for t in tokens]) if tokens else 0
    }
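
# For example, "int x = 1;" tokenizes to ['int', 'x', '1'], giving
# token_count=3, unique_tokens=3, avg_token_length=(3+1+1)/3≈1.67.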

def get_syntactic_features(ast_tree):
    """Extract syntactic features (for Type-3 clones)"""
    if not ast_tree:
        return {}

    # Count different node types in the AST
    node_counts = defaultdict(int)

    def count_nodes(node):
        node_counts[type(node).__name__] += 1
        for child in node.children:
            if isinstance(child, javalang.ast.Node):
                count_nodes(child)
            elif isinstance(child, (list, tuple)):
                for item in child:
                    if isinstance(item, javalang.ast.Node):
                        count_nodes(item)

    count_nodes(ast_tree)
    return dict(node_counts)
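
# A typical result (keys depend on the input) looks like
# {'CompilationUnit': 1, 'ClassDeclaration': 1, 'MethodDeclaration': 2, ...}.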

def get_semantic_features(code):
    """Extract semantic features (for Type-4 clones)"""
    embedding = get_embedding(code)
    return embedding.cpu().numpy().flatten() if embedding is not None else None

# Embedding generation
def get_embedding(code):
    try:
        code = normalize_code(code)
        inputs = tokenizer(
            code,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
            padding='max_length'
        ).to(DEVICE)
        with torch.no_grad():
            outputs = code_model(**inputs)
        # Mean-pool the token embeddings into a single (1, 768) vector
        return outputs.last_hidden_state.mean(dim=1)
    except Exception as e:
        st.error(f"Error processing code: {str(e)}")
        return None

# Clone detection models
def rnn_similarity(emb1, emb2):
    """Calculate similarity using the RNN model"""
    if emb1 is None or emb2 is None:
        return None

    # Prepare input for the RNN (a two-step sequence of embeddings)
    combined = torch.cat([emb1.unsqueeze(0), emb2.unsqueeze(0)], dim=0)
    with torch.no_grad():
        similarity = rnn_model(combined.permute(1, 0, 2))
    return torch.sigmoid(similarity).item()
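
# Shape walkthrough for rnn_similarity: each pooled embedding is (1, 768),
# so `combined` is (2, 1, 768) and the permute yields (1, 2, 768): one batch
# element whose sequence is the two code embeddings, as the batch_first RNN
# expects.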

def gnn_similarity(ast1, ast2):
    """Calculate similarity using the GNN model"""
    if ast1 is None or ast2 is None:
        return None

    data1 = ast_to_pyg_data(ast1)
    data2 = ast_to_pyg_data(ast2)
    if data1 is None or data2 is None:
        return None

    # Move data to device
    data1 = data1.to(DEVICE)
    data2 = data2.to(DEVICE)
    with torch.no_grad():
        score1 = gnn_model(data1)
        score2 = gnn_model(data2)
    # The model returns a scalar score per graph, so cosine similarity is
    # undefined here (the original F.cosine_similarity call raised at
    # runtime). Compare the two scores by closeness instead; both lie in
    # (0, 1) after the sigmoid.
    return 1.0 - abs(score1.item() - score2.item())

def hybrid_similarity(code1, code2):
    """Combined similarity score using all models"""
    # Get embeddings
    emb1 = get_embedding(code1)
    emb2 = get_embedding(code2)

    # Parse ASTs
    ast_tree1 = parse_ast(code1)
    ast_tree2 = parse_ast(code2)
    ast_graph1 = build_ast_graph(ast_tree1) if ast_tree1 else None
    ast_graph2 = build_ast_graph(ast_tree2) if ast_tree2 else None

    # Calculate individual similarities
    codebert_sim = (F.cosine_similarity(emb1, emb2).item()
                    if emb1 is not None and emb2 is not None else 0)
    rnn_sim = rnn_similarity(emb1, emb2) if emb1 is not None and emb2 is not None else 0
    gnn_sim = gnn_similarity(ast_graph1[0] if ast_graph1 else None,
                             ast_graph2[0] if ast_graph2 else None) or 0

    # Combine with weights (can be tuned)
    weights = {
        'codebert': 0.4,
        'rnn': 0.3,
        'gnn': 0.3
    }
    combined = (weights['codebert'] * codebert_sim +
                weights['rnn'] * rnn_sim +
                weights['gnn'] * gnn_sim)

    return {
        'combined': combined,
        'codebert': codebert_sim,
        'rnn': rnn_sim,
        'gnn': gnn_sim
    }
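
# The returned dict might look like (illustrative values only):
#   {'combined': 0.78, 'codebert': 0.90, 'rnn': 0.62, 'gnn': 0.78}
# where 'combined' is the weighted sum 0.4*0.90 + 0.3*0.62 + 0.3*0.78.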

# Comparison function
def compare_code(code1, code2):
    if not code1 or not code2:
        return None

    with st.spinner('Analyzing code with multiple techniques...'):
        # Get lexical features
        lex1 = get_lexical_features(code1)
        lex2 = get_lexical_features(code2)

        # Get AST trees
        ast_tree1 = parse_ast(code1)
        ast_tree2 = parse_ast(code2)

        # Get syntactic features
        syn1 = get_syntactic_features(ast_tree1)
        syn2 = get_syntactic_features(ast_tree2)

        # Get semantic features (computed for completeness; not included in
        # the returned dict)
        sem1 = get_semantic_features(code1)
        sem2 = get_semantic_features(code2)

        # Calculate hybrid similarity
        similarities = hybrid_similarity(code1, code2)

        return {
            'similarities': similarities,
            'lexical_features': (lex1, lex2),
            'syntactic_features': (syn1, syn2),
            'ast_trees': (ast_tree1, ast_tree2)
        }

# UI Elements
st.title("🔍 Advanced Java Code Clone Detector (IJaDataset 2.1)")
st.markdown("""
Detect all types of code clones (Type-1 through Type-4) using a hybrid approach with:
- **CodeBERT** for semantic analysis
- **RNN** for sequence modeling
- **GNN** for AST structural analysis
""")

# Dataset selector
selected_pair = None
if dataset_pairs:
    pair_options = {f"{i+1}: {pair['type']}": pair for i, pair in enumerate(dataset_pairs)}
    selected_option = st.selectbox("Select a preloaded example pair:", list(pair_options.keys()))
    selected_pair = pair_options[selected_option]

# Layout
col1, col2 = st.columns(2)
with col1:
    code1 = st.text_area(
        "First Java Code",
        height=300,
        value=selected_pair["code1"] if selected_pair else "",
        help="Enter the first Java code snippet"
    )
with col2:
    code2 = st.text_area(
        "Second Java Code",
        height=300,
        value=selected_pair["code2"] if selected_pair else "",
        help="Enter the second Java code snippet"
    )

# Threshold sliders
st.subheader("Detection Thresholds")
col1, col2, col3 = st.columns(3)
with col1:
    threshold_type12 = st.slider(
        "Type 1/2 Threshold",
        min_value=0.5,
        max_value=1.0,
        value=0.9,
        step=0.01,
        help="Threshold for exact/syntactic clones"
    )
with col2:
    threshold_type3 = st.slider(
        "Type 3 Threshold",
        min_value=0.5,
        max_value=1.0,
        value=0.8,
        step=0.01,
        help="Threshold for near-miss clones"
    )
with col3:
    threshold_type4 = st.slider(
        "Type 4 Threshold",
        min_value=0.5,
        max_value=1.0,
        value=0.7,
        step=0.01,
        help="Threshold for semantic clones"
    )

# Compare button
if st.button("Compare Code", type="primary"):
    if tokenizer is None or code_model is None or rnn_model is None or gnn_model is None:
        st.error("Models failed to load. Please check the logs.")
    else:
        result = compare_code(code1, code2)
        if result is not None:
            similarities = result['similarities']
            lex1, lex2 = result['lexical_features']
            syn1, syn2 = result['syntactic_features']
            ast_tree1, ast_tree2 = result['ast_trees']

            # Display results
            st.subheader("Detection Results")

            # Determine clone type
            combined_sim = similarities['combined']
            clone_type = "No Clone"
            if combined_sim >= threshold_type12:
                clone_type = "Type 1/2 Clone (Exact/Near-Exact)"
            elif combined_sim >= threshold_type3:
                clone_type = "Type 3 Clone (Near-Miss)"
            elif combined_sim >= threshold_type4:
                clone_type = "Type 4 Clone (Semantic)"

            # Main metrics
            col1, col2, col3 = st.columns(3)
            with col1:
                st.metric("Combined Similarity", f"{combined_sim:.3f}")
            with col2:
                st.metric("Detected Clone Type", clone_type)
            with col3:
                st.metric("CodeBERT Similarity", f"{similarities['codebert']:.3f}")

            # Detailed metrics
            with st.expander("Detailed Similarity Scores"):
                cols = st.columns(3)
                with cols[0]:
                    st.metric("RNN Similarity", f"{similarities['rnn']:.3f}")
                with cols[1]:
                    st.metric("GNN Similarity", f"{similarities['gnn']:.3f}")
                with cols[2]:
                    # Fraction of lexical features whose values match exactly
                    # (a coarse measure; avg_token_length rarely matches)
                    st.metric("Lexical Similarity",
                              f"{sum(lex1[k] == lex2[k] for k in lex1) / max(len(lex1), 1):.2f}")

            # Feature comparison
            with st.expander("Feature Analysis"):
                st.subheader("Lexical Features")
                lex_df = pd.DataFrame([lex1, lex2], index=["Code 1", "Code 2"])
                st.dataframe(lex_df)

                st.subheader("Syntactic Features (AST Node Counts)")
                syn_df = pd.DataFrame([syn1, syn2], index=["Code 1", "Code 2"]).fillna(0)
                st.dataframe(syn_df)

            # AST Visualization
            if ast_tree1 and ast_tree2:
                with st.expander("AST Visualization (First 20 nodes)"):
                    st.write("AST visualization placeholder; the hypothetical "
                             "ast_to_graphviz sketch above could render it:")
                    # st.graphviz_chart(ast_to_graphviz(ast_tree1))
                    # st.graphviz_chart(ast_to_graphviz(ast_tree2))

            # Normalized code view
            with st.expander("Show normalized code"):
                tab1, tab2 = st.tabs(["First Code", "Second Code"])
                with tab1:
                    st.code(normalize_code(code1))
                with tab2:
                    st.code(normalize_code(code2))

# Footer
st.markdown("---")
st.markdown("""
*Dataset Information*:
- Using IJaDataset 2.1 from Kaggle
- Contains 100K Java files with clone annotations
- Clone types: Type-1, Type-2, Type-3, and Type-4 clones

*Model Architecture*:
- **CodeBERT**: Pre-trained model for semantic analysis
- **RNN**: Processes token sequences for sequential patterns
- **GNN**: Analyzes AST structure for syntactic patterns
- **Hybrid Approach**: Combines all techniques for comprehensive detection
""")