Spaces:
Runtime error
Runtime error
| import os | |
| import streamlit as st | |
| import javalang | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import re | |
| import numpy as np | |
| import networkx as nx | |
| from transformers import AutoTokenizer, AutoModel | |
| import warnings | |
| import pandas as pd | |
| from collections import defaultdict | |
# --- Runtime configuration ---------------------------------------------
# Set the TOKENIZERS_PARALLELISM environment variable to "false" and
# suppress every Python warning for the lifetime of the process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")

# --- Constants ---------------------------------------------------------
# MODEL_NAME: checkpoint identifier handed to ``from_pretrained``.
# MAX_LENGTH: token budget passed to the tokenizer as ``max_length``.
# DEVICE:     CUDA when torch reports it available, otherwise CPU.
MODEL_NAME = "microsoft/codebert-base"
MAX_LENGTH = 512
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Streamlit page setup ----------------------------------------------
st.set_page_config(page_title="Java Code Clone Detector", page_icon="π", layout="wide")
# Simplified RNN model (for Hugging Face compatibility)
class SimpleRNN(nn.Module):
    """Recurrent scorer: an ``nn.RNN`` followed by a one-unit linear head.

    Args:
        input_size: Feature width of every sequence step (default 768).
        hidden_size: Width of the recurrent hidden state (default 128).
    """

    def __init__(self, input_size=768, hidden_size=128):
        super().__init__()
        # Creation order (rnn first, then fc) and the attribute names are
        # kept so parameter initialisation and state-dict keys stay the same.
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=1)

    def forward(self, x):
        """Return a sigmoid score computed from the last time step.

        Args:
            x: Batch-first input tensor (the RNN is built with
                ``batch_first=True``).

        Returns:
            Tensor with one sigmoid-squashed value per batch element.
        """
        step_outputs, _final_hidden = self.rnn(x)
        last_step = step_outputs[:, -1]
        logit = self.fc(last_step)
        return logit.sigmoid()
# Model loading with caching
@st.cache_resource(show_spinner=False)
def _load_models_cached():
    """Build the tokenizer, CodeBERT encoder and RNN scorer once per process.

    Wrapped in ``st.cache_resource`` so Streamlit reruns (triggered by every
    widget interaction) reuse the same objects instead of reloading the
    checkpoint. Exceptions propagate to the caller, so a failed load is not
    stored in the cache and will be retried on the next run.

    Returns:
        Tuple ``(tokenizer, code_model, rnn_model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    code_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
    code_model.eval()  # inference only
    # NOTE: SimpleRNN is instantiated with fresh weights; no checkpoint is
    # loaded here, so its score is not the output of a trained model.
    rnn_model = SimpleRNN().to(DEVICE)
    rnn_model.eval()
    return tokenizer, code_model, rnn_model


def load_models():
    """Return the (cached) models, reporting failures in the UI.

    Returns:
        Tuple ``(tokenizer, code_model, rnn_model)`` on success, or
        ``(None, None, None)`` after showing an ``st.error`` message when
        loading raises.
    """
    try:
        with st.spinner('Loading models (first run may take a few minutes)...'):
            return _load_models_cached()
    except Exception as e:
        # Top-level UI boundary: surface the problem instead of crashing.
        st.error(f"Model loading failed: {str(e)}")
        return None, None, None
# AST processing (simplified for Hugging Face)
def parse_ast(code):
    """Parse Java source text with ``javalang.parse.parse``.

    Args:
        code: Java source text.

    Returns:
        Whatever ``javalang.parse.parse`` returns, or ``None`` when parsing
        raises an exception.
    """
    try:
        return javalang.parse.parse(code)
    except Exception:
        # Best-effort: unparsable input simply yields no AST. Unlike a bare
        # ``except:``, this lets KeyboardInterrupt/SystemExit propagate.
        return None
def build_simple_ast_features(ast_tree):
    """Count how often each node type occurs in a javalang AST.

    The tree is walked iteratively with an explicit stack, so deeply nested
    code cannot exhaust the interpreter's recursion limit. Lists and tuples
    found among a node's children are descended into at any nesting depth.

    Args:
        ast_tree: Root of the tree (as returned by ``parse_ast``), or a falsy
            value when no tree is available.

    Returns:
        Dict mapping node class name to occurrence count; ``{}`` when
        ``ast_tree`` is falsy.
    """
    if not ast_tree:
        return {}

    features = defaultdict(int)
    # The root is always counted, whatever its type.
    features[type(ast_tree).__name__] += 1

    # Children are pushed reversed so that popping from the end visits them
    # in their original left-to-right (pre-order) sequence.
    pending = list(reversed(getattr(ast_tree, 'children', [])))
    while pending:
        item = pending.pop()
        if isinstance(item, javalang.ast.Node):
            features[type(item).__name__] += 1
            pending.extend(reversed(getattr(item, 'children', [])))
        elif isinstance(item, (list, tuple)):
            pending.extend(reversed(item))
        # Anything else (strings, sets, None, ...) is not a node: skip it.
    return dict(features)
# Feature extraction
# One alternation that recognises string literals, char literals, line
# comments and block comments. Matching literals first keeps ``//`` or ``/*``
# that appear inside quotes from being mistaken for comments.
_JAVA_LITERAL_OR_COMMENT_RE = re.compile(
    r'"(?:\\.|[^"\\\n])*"'      # double-quoted string literal
    r"|'(?:\\.|[^'\\\n])*'"     # single-quoted char literal
    r"|//[^\n]*"                # line comment (up to end of line)
    r"|/\*.*?\*/",              # block comment (may span lines)
    re.DOTALL,
)


def _drop_comment(match):
    """Replace a comment with one space; return literals unchanged."""
    token = match.group(0)
    return ' ' if token.startswith('/') else token


def normalize_code(code):
    """Strip Java comments and collapse whitespace.

    Comments are replaced by a single space so the tokens on either side do
    not fuse (``int/*c*/x`` becomes ``int x``). Comment markers inside string
    or character literals are left alone.

    Args:
        code: Java source text.

    Returns:
        The source with comments removed, every whitespace run collapsed to
        one space, and leading/trailing whitespace stripped.
    """
    without_comments = _JAVA_LITERAL_OR_COMMENT_RE.sub(_drop_comment, code)
    return re.sub(r'\s+', ' ', without_comments).strip()
def get_embedding(code, tokenizer, model):
    """Embed a code snippet as the mean of its token vectors.

    The snippet is normalised, tokenised (truncated to ``MAX_LENGTH``) and run
    through ``model``; the final hidden states are averaged over the positions
    marked in the attention mask. The input is not padded to ``max_length``:
    for a single sequence padding adds nothing, and averaging over pad
    positions would let them dominate the embedding of short snippets.

    Args:
        code: Source text to embed.
        tokenizer: Callable tokenizer returning a batch with an
            ``attention_mask`` entry and a ``.to(device)`` method.
        model: Encoder whose output exposes ``last_hidden_state``.

    Returns:
        Tensor with one pooled vector per input sequence, or ``None`` if
        tokenisation or the forward pass raises.
    """
    try:
        inputs = tokenizer(
            normalize_code(code),
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
        ).to(DEVICE)
        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state
        # Mask-weighted mean: only real tokens contribute to the average.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        token_count = mask.sum(dim=1).clamp(min=1.0)  # guard against /0
        return (hidden * mask).sum(dim=1) / token_count
    except Exception:
        # Best-effort: callers treat None as "no embedding available".
        return None
# Similarity calculations (optimized for Hugging Face)
def calculate_similarities(code1, code2, models):
    """Score two code snippets with three signals and a weighted blend.

    Args:
        code1: First source snippet.
        code2: Second source snippet.
        models: Triple ``(tokenizer, code_model, rnn_model)``.

    Returns:
        Dict with keys ``'codebert'`` (cosine similarity of the embeddings),
        ``'rnn'`` (RNN score for the embedding pair), ``'ast'`` (overlap of
        AST node-type sets) and ``'combined'`` (0.5/0.3/0.2 weighted sum).
        A signal whose inputs are unavailable contributes 0.
    """
    tokenizer, code_model, rnn_model = models

    # Embedding-based signals need both vectors.
    vec_a = get_embedding(code1, tokenizer, code_model)
    vec_b = get_embedding(code2, tokenizer, code_model)

    codebert_sim = 0
    rnn_sim = 0
    have_both_vectors = not (vec_a is None or vec_b is None)
    if have_both_vectors:
        codebert_sim = F.cosine_similarity(vec_a, vec_b).item()
        with torch.no_grad():
            # Stack the two vectors into one length-2 sequence for the RNN.
            pair_sequence = torch.cat([vec_a, vec_b]).unsqueeze(0)
            rnn_sim = rnn_model(pair_sequence).item()

    # Structural signal: shared node types over all node types seen in
    # either tree. Only the key sets are compared; the counts are not used.
    feats_a = build_simple_ast_features(parse_ast(code1))
    feats_b = build_simple_ast_features(parse_ast(code2))

    ast_sim = 0
    if feats_a and feats_b:
        types_a = set(feats_a)
        types_b = set(feats_b)
        all_types = types_a | types_b
        shared_types = types_a & types_b
        ast_sim = len(shared_types) / len(all_types) if all_types else 0

    scores = {
        'codebert': codebert_sim,
        'rnn': rnn_sim,
        'ast': ast_sim,
    }
    scores['combined'] = 0.5*codebert_sim + 0.3*rnn_sim + 0.2*ast_sim
    return scores
# Main UI
def _classify_clone(score, t1, t3, t4):
    """Map a combined score to a clone label, checking thresholds t1, t3, t4 in order."""
    if score >= t1:
        return "Type 1/2 Clone (Exact/Near-Exact)"
    if score >= t3:
        return "Type 3 Clone (Near-Miss)"
    if score >= t4:
        return "Type 4 Clone (Semantic)"
    return "No Clone"


def main():
    """Render the Streamlit page: inputs, thresholds, analysis and results."""
    st.title("π Java Code Clone Detector (IJaDataset 2.1)")
    st.markdown("Detect Type 1-4 clones using hybrid analysis")

    # Load models; identity check avoids invoking ``==`` on model objects.
    models = load_models()
    if any(component is None for component in models):
        st.error("Failed to load required models. Please check the logs.")
        return

    # Example code pairs
    example_pairs = {
        "Type 1 Example": {
            "code1": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }",
            "code2": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }"
        },
        "Type 2 Example": {
            "code1": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }",
            "code2": "public class Example { public static void main(String[] args) { System.out.println(\"Hello\"); } }"
        },
        "Type 3 Example": {
            "code1": "public class Test { public static void main(String[] args) { for(int i=0;i<10;i++) System.out.println(i); } }",
            "code2": "public class Example { public static void run(String[] params) { for(int j=0;j<10;j++) System.out.println(j); } }"
        }
    }

    # Code input, pre-filled from the selected example pair
    selected_example = st.selectbox("Select example pair:", list(example_pairs.keys()))
    col1, col2 = st.columns(2)
    with col1:
        code1 = st.text_area(
            "Code 1",
            height=300,
            value=example_pairs[selected_example]["code1"]
        )
    with col2:
        code2 = st.text_area(
            "Code 2",
            height=300,
            value=example_pairs[selected_example]["code2"]
        )

    # Thresholds
    st.subheader("Detection Thresholds")
    cols = st.columns(3)
    with cols[0]:
        t1 = st.slider("Type 1/2", 0.85, 1.0, 0.95)
    with cols[1]:
        t3 = st.slider("Type 3", 0.7, 0.9, 0.8)
    with cols[2]:
        t4 = st.slider("Type 4", 0.5, 0.8, 0.65)

    # Analysis button
    if st.button("Analyze Code", type="primary"):
        with st.spinner("Analyzing code..."):
            sims = calculate_similarities(code1, code2, models)
            clone_type = _classify_clone(sims['combined'], t1, t3, t4)

        # Display results
        st.subheader("Results")

        # Metrics
        cols = st.columns(4)
        cols[0].metric("Combined", f"{sims['combined']:.2f}")
        cols[1].metric("CodeBERT", f"{sims['codebert']:.2f}")
        cols[2].metric("RNN", f"{sims['rnn']:.2f}")
        cols[3].metric("AST", f"{sims['ast']:.2f}")

        # Progress bar: cosine similarity may be negative, so clamp the
        # combined score into the [0.0, 1.0] range st.progress accepts.
        st.progress(min(max(float(sims['combined']), 0.0), 1.0))

        # Final result
        st.metric("Detection Result", clone_type)

        # Show details
        with st.expander("Advanced Details"):
            st.json(sims)
            st.code(f"Normalized Code 1:\n{normalize_code(code1)}")
            st.code(f"Normalized Code 2:\n{normalize_code(code2)}")


if __name__ == "__main__":
    main()