File size: 19,737 Bytes
432dc67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
#!/usr/bin/env python3
"""
Multi-language Function Parsing Script
Scans code files in each repository, uses Qwen to parse dependencies and functions,
generates functions_with_context.csv
"""

import argparse
import asyncio
import csv
import hashlib
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

from dotenv import load_dotenv
from tqdm import tqdm

# Load .env before importing the project modules below, in case they read
# environment variables at import time. The script's own directory is checked
# first; the parent directory (presumably the project root) is the fallback.
env_file = Path(__file__).parent / ".env"
if env_file.exists():
    load_dotenv(env_file)
elif (Path(__file__).parent.parent / ".env").exists():
    # No .env next to this script: fall back to the parent directory's .env
    load_dotenv(Path(__file__).parent.parent / ".env")

# Make this script's directory importable (provides the local `schemas` module)
sys.path.insert(0, str(Path(__file__).parent))
# Make ../domain_code/src importable to reuse its `util` helpers. Because this
# insert also goes to index 0, domain_code/src is searched BEFORE the script's
# own directory when the imports below are resolved.
sys.path.insert(0, str(Path(__file__).parent.parent / "domain_code" / "src"))
from util import call_llm, init_logger, logger, CODE_EXTENSIONS
from schemas import FileParseResult

# Every extension in CODE_EXTENSIONS except Markdown, which is documentation
# rather than code. (The set difference assumes CODE_EXTENSIONS is a set of
# suffixes such as ".py" -- it is defined in util; confirm there.)
PARSEABLE_CODE_EXTENSIONS = CODE_EXTENSIONS - {".md", ".markdown"}


# Name of the CSV written into each repository directory
CSV_FILENAME = "functions_with_context.csv"
# Name of the per-repository summary file read as project context
SUMMARY_FILENAME = "README_SUMMARY.md"


# Lower-case file suffix -> language name reported by detect_language().
# Keys must be lower-case: the lookup lower-cases the suffix first, so ".F"
# and ".R" files resolve through the ".f" and ".r" entries.
_LANGUAGE_BY_EXTENSION = {
    ".py": "python",
    ".ipynb": "python",
    ".java": "java",
    ".c": "c",
    ".cpp": "cpp",
    ".cc": "cpp",
    ".cxx": "cpp",
    ".h": "cpp",
    ".hpp": "cpp",
    ".hh": "cpp",
    ".f90": "fortran",
    ".f": "fortran",
    ".f95": "fortran",
    ".r": "r",
    ".m": "matlab",
    ".sh": "shell",
    ".bash": "shell",
    ".rs": "rust",
    ".go": "go",
    ".jl": "julia",
}


def detect_language(file_path: Path) -> str:
    """
    Detect programming language based on file extension
    
    The suffix is matched case-insensitively. Suffixes without a known mapping
    are returned as-is without the leading dot (e.g. "x.js" -> "js").
    
    Args:
        file_path: File path
        
    Returns:
        Programming language name (e.g., python, cpp, java), or "unknown"
        when the file has no suffix
    """
    ext = file_path.suffix.lower()
    if not ext:
        return "unknown"
    return _LANGUAGE_BY_EXTENSION.get(ext, ext.lstrip("."))


def read_readme_summary(repo_dir: Path) -> Optional[str]:
    """
    Load the repository's README summary file to use as project context
    
    Args:
        repo_dir: Repository root directory
        
    Returns:
        The summary text with surrounding whitespace stripped (undecodable
        bytes are dropped), or None when the file is missing or unreadable
    """
    summary_path = repo_dir / SUMMARY_FILENAME
    if summary_path.exists():
        try:
            text = summary_path.read_text(encoding="utf-8", errors="ignore")
        except Exception as err:
            logger.warning(f"Unable to read README summary file {summary_path}: {err}")
            return None
        return text.strip()
    return None


# Directory names never descended into (dot-directories such as .git are
# excluded separately by the leading-dot check in find_code_files)
_SKIPPED_DIR_NAMES = frozenset({"__pycache__", "node_modules"})


def find_code_files(
    repo_dir: Path,
    max_file_chars: int = 200000,
    extensions: Optional[Iterable[str]] = None,
) -> List[Path]:
    """
    Find all code files in the repository
    
    Hidden directories (including .git), __pycache__ and node_modules are not
    descended into.
    
    Args:
        repo_dir: Repository root directory
        max_file_chars: Maximum file size; files larger than this are skipped.
            It is compared against the on-disk byte size as a cheap stand-in
            for the character count.
        extensions: File suffixes to accept (e.g. [".py"]), matched
            case-insensitively; defaults to PARSEABLE_CODE_EXTENSIONS
            (CODE_EXTENSIONS without Markdown)
        
    Returns:
        Sorted list of code file paths
    """
    allowed_suffixes = (
        PARSEABLE_CODE_EXTENSIONS
        if extensions is None
        else {suffix.lower() for suffix in extensions}
    )
    code_files = []
    
    for root, dirs, files in os.walk(repo_dir):
        # Prune in place so os.walk does not descend into skipped directories
        dirs[:] = [d for d in dirs if not d.startswith(".") and d not in _SKIPPED_DIR_NAMES]
        
        for file in files:
            file_path = Path(root) / file
            if file_path.suffix.lower() not in allowed_suffixes:
                continue
            try:
                size = file_path.stat().st_size
            except OSError as e:
                # e.g. broken symlink or permission denied: skip, keep scanning
                logger.warning(f"Unable to get file size {file_path}: {e}")
                continue
            if size <= max_file_chars:
                code_files.append(file_path)
            else:
                logger.debug(f"Skipping large file: {file_path} ({size} bytes)")
    
    return sorted(code_files)


def read_code_file(file_path: Path) -> Optional[str]:
    """
    Return the text of a source file decoded as UTF-8
    
    Undecodable bytes are dropped (errors="ignore") and line endings are
    normalized by text-mode reading.
    
    Args:
        file_path: File path
        
    Returns:
        File content, or None if the file could not be opened or read
    """
    content: Optional[str] = None
    try:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as handle:
            content = handle.read()
    except Exception as exc:
        logger.warning(f"Unable to read file {file_path}: {exc}")
    return content


def compute_file_hash(file_path: Path, content: str) -> str:
    """
    Compute the SHA1 hash of a file's decoded text content
    
    The digest is taken over ``content`` encoded as UTF-8, not over the raw
    bytes on disk. It can therefore differ from ``sha1sum`` output when the
    text was decoded lossily or with newline translation.
    
    Args:
        file_path: File path (not used in the digest; kept for call compatibility)
        content: File content
        
    Returns:
        SHA1 hash (hex string)
    """
    return hashlib.sha1(content.encode("utf-8")).hexdigest()


def compute_function_hash(repo_name: str, path: str, start_line: int, end_line: int, body: str) -> str:
    """
    Compute function hash (for deduplication)
    
    Args:
        repo_name: Repository name
        path: Relative file path
        start_line: Function start line number
        end_line: Function end line number
        body: Function body
        
    Returns:
        SHA1 hash (hex string)
    """
    key = f"{repo_name}:{path}:{start_line}:{end_line}:{body}"
    return hashlib.sha1(key.encode("utf-8")).hexdigest()


async def parse_code_file(
    file_path: Path,
    repo_dir: Path,
    project_context: str,
    base_url: str,
    model: str,
    api_key: str,
    log_file: str,
) -> Optional[Dict]:
    """
    Use LLM to parse code file, extract dependencies and function information
    
    Args:
        file_path: Code file path
        repo_dir: Repository root directory
        project_context: Project context (README summary)
        base_url: LLM API base URL
        model: Model name
        api_key: API key
        log_file: Log file path
        
    Returns:
        Parse result dict, always carrying "file_path" (relative to repo_dir)
        and "language" keys. Returns None when the file is empty or unreadable,
        the prompt template cannot be read, the LLM call fails, or the
        response is not a JSON object.
    """
    # Empty or unreadable files have nothing to parse
    code_content = read_code_file(file_path)
    if not code_content:
        return None
    
    language = detect_language(file_path)
    rel_path = str(file_path.relative_to(repo_dir))
    
    # Read prompt template
    prompt_template_path = Path(__file__).parent / "prompts" / "function_extract.txt"
    try:
        with open(prompt_template_path, "r", encoding="utf-8") as f:
            prompt_template = f.read()
    except Exception as e:
        logger.error(f"Unable to read prompt template: {e}")
        return None
    
    # Build prompt (substituted values are not re-parsed, so braces inside the
    # code content are safe)
    prompt = prompt_template.format(
        project_context=project_context or "(No project context)",
        file_path=rel_path,
        language=language,
        code_content=code_content,
    )
    
    messages = [{"role": "user", "content": prompt}]
    
    try:
        result = await call_llm(
            messages=messages,
            model=model,
            base_url=base_url,
            api_key=api_key,
            pydantic_object=FileParseResult,
            log_file=log_file,
        )
        
        if result is None:
            logger.warning(f"LLM call returned None, skipping file: {rel_path}")
            return None
        
        # If result is a string, try to parse JSON
        if isinstance(result, str):
            try:
                result = json.loads(result)
            except json.JSONDecodeError:
                logger.warning(f"Unable to parse JSON from LLM response: {result[:200]}")
                return None
        
        # NOTE(review): call_llm is given pydantic_object=FileParseResult; if it
        # hands back a model instance rather than a dict, convert it (pydantic v2
        # exposes model_dump). Confirm against util.call_llm.
        if not isinstance(result, dict) and hasattr(result, "model_dump"):
            result = result.model_dump()
        
        # Callers index result["file_path"] etc., so anything that is still not a
        # dict (e.g. a JSON list) must be rejected here rather than passed on
        if not isinstance(result, dict):
            logger.warning(
                f"LLM response is not a JSON object ({type(result).__name__}), skipping file: {rel_path}"
            )
            return None
        
        # Overwrite with the authoritative values instead of trusting the LLM
        result["file_path"] = rel_path
        result["language"] = language
        return result
    except Exception as e:
        logger.error(f"LLM call failed (file: {rel_path}): {e}")
        return None


def extract_repo_name(repo_dir: Path) -> str:
    """
    Turn a checkout directory name back into an owner/repo identifier
    
    Directory names encode "/" as a triple underscore (owner___repo ->
    owner/repo); every occurrence is converted.
    
    Args:
        repo_dir: Repository root directory
        
    Returns:
        Repository name (owner/repo format)
    """
    return "/".join(repo_dir.name.split("___"))


# Column order of the generated functions_with_context.csv
_CSV_FIELDNAMES = [
    "repo_name",
    "readme_summary_path",
    "readme_summary_text",
    "path",
    "language",
    "dependencies",
    "function_name",
    "function_start_line",
    "function_end_line",
    "function_body",
    "doc_start_line",
    "doc_end_line",
    "file_size_bytes",
    "file_sha1",
    "function_hash",
    "ds_source",
]

# Maximum number of README summary characters copied into each CSV row
_MAX_CONTEXT_CHARS = 5000


def _iter_function_rows(
    parse_results: List[Dict],
    repo_dir: Path,
    repo_name_normalized: str,
    project_context: str,
):
    """
    Yield one CSV row dict (keyed by _CSV_FIELDNAMES) per function in parse_results
    
    Args:
        parse_results: Parse result dicts produced by parse_code_file
        repo_dir: Repository root directory
        repo_name_normalized: Repository name in owner/repo format
        project_context: README summary text (truncated to _MAX_CONTEXT_CHARS)
    """
    # Loop-invariant: every row carries the same truncated summary
    context_text = project_context[:_MAX_CONTEXT_CHARS]
    
    for parse_result in parse_results:
        file_path = parse_result["file_path"]
        language = parse_result["language"]
        # The LLM may emit null instead of an empty list
        dependencies = parse_result.get("dependencies") or []
        functions = parse_result.get("functions") or []
        dependencies_json = json.dumps(dependencies, ensure_ascii=False)
        
        full_file_path = repo_dir / file_path
        # Actual on-disk size; the decoded text can differ from the raw bytes
        try:
            file_size = full_file_path.stat().st_size
        except OSError:
            file_size = 0
        file_content = read_code_file(full_file_path)
        file_sha1 = compute_file_hash(full_file_path, file_content) if file_content else ""
        
        for func in functions:
            if not isinstance(func, dict):
                # One malformed entry must not fail the whole repository
                logger.warning(f"Skipping malformed function entry in {file_path}: {func!r:.200}")
                continue
            
            function_start_line = func.get("function_start_line", 0)
            function_end_line = func.get("function_end_line", 0)
            function_body = func.get("function_body", "")
            doc_start_line = func.get("doc_start_line")
            doc_end_line = func.get("doc_end_line")
            
            yield {
                "repo_name": repo_name_normalized,
                "readme_summary_path": SUMMARY_FILENAME,
                "readme_summary_text": context_text,
                "path": file_path,
                "language": language,
                "dependencies": dependencies_json,
                "function_name": func.get("function_name", ""),
                "function_start_line": function_start_line,
                "function_end_line": function_end_line,
                "function_body": function_body,
                "doc_start_line": doc_start_line if doc_start_line else "",
                "doc_end_line": doc_end_line if doc_end_line else "",
                "file_size_bytes": file_size,
                "file_sha1": file_sha1,
                "function_hash": compute_function_hash(
                    repo_name_normalized,
                    file_path,
                    function_start_line,
                    function_end_line,
                    function_body,
                ),
                "ds_source": "repos_filtered",
            }


def _write_csv_atomically(csv_file: Path, rows) -> int:
    """
    Write rows to csv_file through a temporary sibling file; return the row count
    
    The final path appears only after every row is written. A crash part-way
    through therefore never leaves a truncated CSV that a later run (without
    --overwrite) would skip as already processed. The temporary file is removed
    on failure and the exception is re-raised.
    """
    tmp_file = csv_file.with_name(csv_file.name + ".tmp")
    row_count = 0
    try:
        with open(tmp_file, "w", encoding="utf-8", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=_CSV_FIELDNAMES)
            writer.writeheader()
            for row in rows:
                writer.writerow(row)
                row_count += 1
        # os.replace overwrites an existing file; the rename is atomic on POSIX
        os.replace(tmp_file, csv_file)
    except BaseException:
        try:
            tmp_file.unlink()
        except OSError:
            pass
        raise
    return row_count


async def process_single_repo(
    repo_dir: Path,
    base_url: str,
    model: str,
    api_key: str,
    log_file: str,
    max_file_chars: int = 200000,
    max_concurrency: int = 8,
    overwrite: bool = False,
) -> Dict[str, Any]:
    """
    Process function parsing for a single repository
    
    Args:
        repo_dir: Repository root directory
        base_url: LLM API base URL
        model: Model name
        api_key: API key
        log_file: Log file path
        max_file_chars: Maximum file size (chars)
        max_concurrency: Maximum number of files parsed concurrently
        overwrite: Whether to overwrite existing CSV file
        
    Returns:
        Processing result dictionary with "repo" and "status" keys. The status is
        one of skipped / no_summary / no_code / parse_failed / write_failed
        (with a "reason") or success (with "csv_file", "file_count" and
        "function_count").
    """
    repo_name = repo_dir.name
    csv_file = repo_dir / CSV_FILENAME
    
    if csv_file.exists() and not overwrite:
        return {
            "repo": repo_name,
            "status": "skipped",
            "reason": "CSV file already exists",
        }
    
    # README summary is required as project context
    project_context = read_readme_summary(repo_dir)
    if not project_context:
        logger.warning(f"Repository {repo_name} has no README_SUMMARY.md, skipping")
        return {
            "repo": repo_name,
            "status": "no_summary",
            "reason": "README_SUMMARY.md not found",
        }
    
    code_files = find_code_files(repo_dir, max_file_chars=max_file_chars)
    if not code_files:
        return {
            "repo": repo_name,
            "status": "no_code",
            "reason": "No code files found",
        }
    
    logger.info(f"Repository {repo_name}: found {len(code_files)} code files")
    
    # The semaphore bounds how many LLM calls are in flight at once
    semaphore = asyncio.Semaphore(max_concurrency)
    
    async def parse_with_semaphore(file_path: Path):
        async with semaphore:
            return await parse_code_file(
                file_path=file_path,
                repo_dir=repo_dir,
                project_context=project_context,
                base_url=base_url,
                model=model,
                api_key=api_key,
                log_file=log_file,
            )
    
    tasks = [parse_with_semaphore(file_path) for file_path in code_files]
    parse_results = []
    
    for task in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc=f"Parsing {repo_name}", leave=False):
        result = await task
        if not result:
            continue
        if isinstance(result, dict):
            parse_results.append(result)
        else:
            logger.warning(
                f"Repository {repo_name}: ignoring non-dict parse result of type {type(result).__name__}"
            )
    
    if not parse_results:
        return {
            "repo": repo_name,
            "status": "parse_failed",
            "reason": "All files failed to parse",
        }
    
    # as_completed yields in completion order; sort so CSV row order is reproducible
    parse_results.sort(key=lambda r: str(r.get("file_path", "")))
    
    repo_name_normalized = extract_repo_name(repo_dir)
    rows = _iter_function_rows(parse_results, repo_dir, repo_name_normalized, project_context)
    
    try:
        function_count = _write_csv_atomically(csv_file, rows)
    except Exception as e:
        logger.error(f"Unable to write CSV file {csv_file}: {e}")
        return {
            "repo": repo_name,
            "status": "write_failed",
            "reason": str(e),
        }
    
    logger.info(f"Repository {repo_name}: wrote {function_count} functions to {csv_file}")
    
    return {
        "repo": repo_name,
        "status": "success",
        "csv_file": str(csv_file),
        "file_count": len(code_files),
        "function_count": function_count,
    }


async def process_all_repos(
    repos_dir: Path,
    base_url: str,
    model: str,
    api_key: str,
    log_file: str,
    max_file_chars: int = 200000,
    max_concurrency: int = 8,
    overwrite: bool = False,
) -> List[Dict]:
    """
    Run function parsing over every repository under repos_dir
    
    Each non-hidden sub-directory of repos_dir is treated as one repository.
    Repositories are handled one at a time, in sorted path order.
    
    Args:
        repos_dir: Directory containing one sub-directory per repository
        base_url: LLM API base URL
        model: Model name
        api_key: API key
        log_file: Log file path
        max_file_chars: Maximum file size (chars)
        max_concurrency: Maximum number of files parsed concurrently per repository
        overwrite: Whether to overwrite existing CSV files
        
    Returns:
        One result dict per repository, in processing order
    """
    repo_dirs = sorted(
        entry
        for entry in repos_dir.iterdir()
        if entry.is_dir() and not entry.name.startswith(".")
    )
    
    logger.info(f"Found {len(repo_dirs)} repositories, starting processing...")
    
    # Sequential over repositories; parallelism happens per file inside
    # process_single_repo, bounded by max_concurrency
    results: List[Dict] = []
    for repo_dir in tqdm(repo_dirs, desc="Processing repos"):
        results.append(
            await process_single_repo(
                repo_dir=repo_dir,
                base_url=base_url,
                model=model,
                api_key=api_key,
                log_file=log_file,
                max_file_chars=max_file_chars,
                max_concurrency=max_concurrency,
                overwrite=overwrite,
            )
        )
    
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Multi-language Function Parsing Tool")
    parser.add_argument(
        "--repos_dir",
        type=str,
        default="/home/weifengsun/tangou1/domain_code/src/workdir/repos_filtered",
        help="Repository root directory path",
    )
    parser.add_argument(
        "--base_url",
        type=str,
        default=os.getenv("OPENAI_BASE_URL", "http://localhost:8000/v1"),
        help="LLM API base URL (default: http://localhost:8000/v1)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen3",
        help="Model name (default: Qwen3)",
    )
    parser.add_argument(
        "--api_key_env",
        type=str,
        default="OPENAI_API_KEY",
        help="API key environment variable name (default: OPENAI_API_KEY)",
    )
    parser.add_argument(
        "--max_concurrency",
        type=int,
        default=8,
        help="Maximum concurrency (default: 8)",
    )
    parser.add_argument(
        "--max_file_chars",
        type=int,
        default=200000,
        help="Maximum file size in chars (default: 200000)",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing CSV files",
    )
    parser.add_argument(
        "--log_file",
        type=str,
        default="instruction_generation/workdir/logs/extract.log",
        help="Log file path",
    )
    
    args = parser.parse_args()
    
    # Create the log directory BEFORE handing the path to init_logger, in case
    # the logger opens the file immediately (on a fresh checkout the directory
    # does not exist yet)
    log_file_path = Path(args.log_file)
    log_file_path.parent.mkdir(parents=True, exist_ok=True)
    
    init_logger(args.log_file, level="INFO")
    
    # Local servers often ignore the key, hence the "none" placeholder
    api_key = os.getenv(args.api_key_env, "none")
    
    repos_dir = Path(args.repos_dir)
    if not repos_dir.exists():
        logger.error(f"Repository directory does not exist: {repos_dir}")
        sys.exit(1)
    
    results = asyncio.run(
        process_all_repos(
            repos_dir=repos_dir,
            base_url=args.base_url,
            model=args.model,
            api_key=api_key,
            log_file=str(log_file_path),
            max_file_chars=args.max_file_chars,
            max_concurrency=args.max_concurrency,
            overwrite=args.overwrite,
        )
    )
    
    # Statistics: repositories per status and total functions written
    status_counts = {}
    total_functions = 0
    for result in results:
        status = result["status"]
        status_counts[status] = status_counts.get(status, 0) + 1
        total_functions += result.get("function_count", 0)
    
    logger.info("\n" + "=" * 80)
    logger.info("Processing complete!")
    logger.info("=" * 80)
    logger.info(f"Total: {len(results)} repositories")
    logger.info(f"Total: {total_functions} functions")
    for status, count in status_counts.items():
        logger.info(f"  {status}: {count}")
    logger.info("=" * 80)