File size: 8,187 Bytes
a0f27fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
"""Security paper pipeline.

Fetches security papers from arXiv (cs.CR + adjacent categories),
finds code URLs, and writes to the database.
"""

import logging
import re
import time
from datetime import datetime, timedelta, timezone

import arxiv
import requests

from src.config import (
    ADJACENT_CATEGORIES,
    GITHUB_TOKEN,
    GITHUB_URL_RE,
    SECURITY_EXCLUDE_RE,
    SECURITY_KEYWORDS,
    SECURITY_LLM_RE,
)
from src.db import create_run, finish_run, insert_papers

log = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# arXiv fetching
# ---------------------------------------------------------------------------


def fetch_arxiv_papers(start: datetime, end: datetime, max_papers: int) -> list[dict]:
    """Collect papers from arXiv: every cs.CR submission plus security-relevant
    papers from the adjacent categories, deduplicated by entry ID.

    Args:
        start: inclusive lower bound on submission time (UTC, aware).
        end: inclusive upper bound on submission time (UTC, aware).
        max_papers: cap per arXiv query; the adjacent-category queries split
            this budget evenly between them.

    Returns:
        Unique paper dicts (shape produced by ``_result_to_dict``), with
        excluded topics filtered out and each entry tagged ``llm_adjacent``.
    """
    client = arxiv.Client(page_size=500, delay_seconds=3.0, num_retries=3)
    collected: dict[str, dict] = {}

    def _windowed(search: arxiv.Search):
        """Yield results published within [start, end].

        Results arrive newest-first, so iteration stops at the first result
        older than `start` instead of paging through the rest.
        """
        for result in client.results(search):
            published = result.published.replace(tzinfo=timezone.utc)
            if published < start:
                return
            if published <= end:
                yield result

    # Primary query: everything in cs.CR, no keyword filtering.
    log.info("Fetching cs.CR papers ...")
    primary = arxiv.Search(
        query="cat:cs.CR",
        max_results=max_papers,
        sort_by=arxiv.SortCriterion.SubmittedDate,
        sort_order=arxiv.SortOrder.Descending,
    )
    for result in _windowed(primary):
        entry = _result_to_dict(result)
        collected[entry["entry_id"]] = entry
    log.info("cs.CR: %d papers", len(collected))

    # Secondary queries: adjacent categories, kept only when security keywords hit.
    for cat in ADJACENT_CATEGORIES:
        secondary = arxiv.Search(
            query=f"cat:{cat}",
            max_results=max_papers // len(ADJACENT_CATEGORIES),
            sort_by=arxiv.SortCriterion.SubmittedDate,
            sort_order=arxiv.SortOrder.Descending,
        )
        added = 0
        for result in _windowed(secondary):
            if not SECURITY_KEYWORDS.search(f"{result.title} {result.summary}"):
                continue
            entry = _result_to_dict(result)
            if entry["entry_id"] in collected:
                continue
            collected[entry["entry_id"]] = entry
            added += 1
        log.info("  %s: %d security-relevant papers", cat, added)

    # Drop excluded topics (blockchain, surveys, etc.) before scoring.
    total_before = len(collected)
    collected = {
        eid: entry
        for eid, entry in collected.items()
        if not SECURITY_EXCLUDE_RE.search(f"{entry['title']} {entry['abstract']}")
    }
    dropped = total_before - len(collected)
    if dropped:
        log.info("Excluded %d papers (blockchain/survey/off-topic)", dropped)

    # Tag LLM-adjacent papers so the scoring prompt can apply hard caps.
    llm_count = 0
    for entry in collected.values():
        entry["llm_adjacent"] = bool(
            SECURITY_LLM_RE.search(f"{entry['title']} {entry['abstract']}")
        )
        llm_count += entry["llm_adjacent"]
    if llm_count:
        log.info("Tagged %d papers as LLM-adjacent", llm_count)

    unique_papers = list(collected.values())
    log.info("Total unique papers: %d", len(unique_papers))
    return unique_papers


def _result_to_dict(result: arxiv.Result) -> dict:
    """Flatten an arxiv.Result into the plain dict shape used by the pipeline."""

    def _flat(text: str) -> str:
        # arXiv embeds hard newlines in long fields; collapse them to spaces.
        return text.replace("\n", " ").strip()

    raw_id = result.entry_id.split("/abs/")[-1]
    # Strip any trailing version marker so "2401.01234v2" becomes "2401.01234".
    versionless_id = re.sub(r"v\d+$", "", raw_id)

    record: dict = {
        "arxiv_id": versionless_id,
        "entry_id": result.entry_id,
        "title": _flat(result.title),
        "authors": [author.name for author in result.authors[:10]],
        "abstract": _flat(result.summary),
        "published": result.published.isoformat(),
        "categories": list(result.categories),
        "pdf_url": result.pdf_url,
        "arxiv_url": result.entry_id,
        "comment": _flat(result.comment or ""),
        "source": "arxiv",
        "github_repo": "",
        "github_stars": None,
        "hf_upvotes": 0,
        "hf_models": [],
        "hf_datasets": [],
        "hf_spaces": [],
    }
    return record


# ---------------------------------------------------------------------------
# Code URL finding
# ---------------------------------------------------------------------------


def extract_github_urls(paper: dict) -> list[str]:
    """Extract GitHub URLs mentioned in a paper's abstract and comment.

    Args:
        paper: paper dict with an ``abstract`` key and optional ``comment``.

    Returns:
        Unique GitHub URLs in order of first mention. Order matters: callers
        take the first URL as the paper's repository, and deduplicating via
        ``set()`` made that choice nondeterministic (string hashing is
        randomized per process). ``dict.fromkeys`` dedupes while preserving
        insertion order.
    """
    text = f"{paper['abstract']} {paper.get('comment', '')}"
    return list(dict.fromkeys(GITHUB_URL_RE.findall(text)))


def search_github_for_paper(title: str, token: str | None) -> str | None:
    """Search the GitHub API for a repository matching *title*.

    Args:
        title: paper title; punctuation is stripped and only the first
            eight words are used as the search query.
        token: optional GitHub token for authenticated (higher-quota) calls.

    Returns:
        The HTML URL of the top-ranked matching repository, or ``None`` when
        nothing matches, the search quota is nearly exhausted, or a request
        fails.
    """
    headers = {"Accept": "application/vnd.github.v3+json"}
    if token:
        # Single token branch (previously two separate `if token:` checks):
        # attach auth and do a best-effort quota check so we stop searching
        # before burning the last few search-API calls.
        headers["Authorization"] = f"token {token}"
        try:
            resp = requests.get(
                "https://api.github.com/rate_limit", headers=headers, timeout=10
            )
            if resp.ok:
                remaining = (
                    resp.json().get("resources", {}).get("search", {}).get("remaining", 0)
                )
                if remaining < 5:
                    return None
        except requests.RequestException:
            pass  # quota check is best-effort; fall through to the search

    # Build the query from the first 8 words of the de-punctuated title.
    query = " ".join(re.sub(r"[^\w\s]", " ", title).split()[:8])

    try:
        resp = requests.get(
            "https://api.github.com/search/repositories",
            params={"q": query, "sort": "updated", "per_page": 3},
            headers=headers,
            timeout=10,
        )
        if not resp.ok:
            return None
        items = resp.json().get("items", [])
        if items:
            return items[0]["html_url"]
    except requests.RequestException:
        pass
    return None


def find_code_urls(papers: list[dict]) -> dict[str, str | None]:
    """Map each paper's entry_id to a code URL, or None when none is found.

    URLs mentioned in the paper text take precedence; only papers without an
    explicit link trigger a GitHub search.
    """
    token = GITHUB_TOKEN or None
    resolved: dict[str, str | None] = {}

    for paper in papers:
        entry_id = paper["entry_id"]
        mentioned = extract_github_urls(paper)
        if mentioned:
            resolved[entry_id] = mentioned[0]
        else:
            resolved[entry_id] = search_github_for_paper(paper["title"], token)
            if not token:
                # Unauthenticated GitHub search is tightly rate-limited;
                # pace ourselves between queries.
                time.sleep(2)

    return resolved


# ---------------------------------------------------------------------------
# Pipeline entry point
# ---------------------------------------------------------------------------


def run_security_pipeline(
    start: datetime | None = None,
    end: datetime | None = None,
    max_papers: int = 300,
) -> int:
    """Run the full security pipeline: fetch papers, find code, insert to DB.

    Args:
        start: window start; defaults to 7 days before *end*. Naive values
            are assumed UTC.
        end: window end; defaults to now (UTC). Naive values are assumed UTC
            and pushed to end-of-day (23:59:59).
        max_papers: cap passed to the arXiv fetch.

    Returns:
        The database run ID (even on an empty result set).

    Raises:
        Re-raises any exception from the pipeline steps after logging it and
        marking the run as failed.
    """
    if end is None:
        end = datetime.now(timezone.utc)
    if start is None:
        start = end - timedelta(days=7)

    # Normalize naive datetimes to UTC; a naive end means "the whole day".
    if start.tzinfo is None:
        start = start.replace(tzinfo=timezone.utc)
    if end.tzinfo is None:
        end = end.replace(tzinfo=timezone.utc, hour=23, minute=59, second=59)

    run_id = create_run("security", start.date().isoformat(), end.date().isoformat())
    log.info("Run %d: %s to %s", run_id, start.date(), end.date())

    try:
        # Step 1: Fetch papers
        papers = fetch_arxiv_papers(start, end, max_papers)

        if not papers:
            log.info("No papers found")
            finish_run(run_id, 0)
            return run_id

        # Step 2: Find code URLs
        log.info("Searching for code repositories ...")
        code_urls = find_code_urls(papers)
        with_code = sum(1 for url in code_urls.values() if url)
        log.info("Found code for %d/%d papers", with_code, len(papers))

        # Attach code URLs to papers as github_repo
        for paper in papers:
            url = code_urls.get(paper["entry_id"])
            if url:
                paper["github_repo"] = url

        # Step 3: Insert into DB
        insert_papers(papers, run_id, "security")
        finish_run(run_id, len(papers))
        log.info("Done — %d papers inserted", len(papers))
        return run_id

    except Exception:
        # Log first: if finish_run itself raised, the original traceback
        # would otherwise be lost. (Also drops the unused `as e` binding.)
        log.exception("Pipeline failed")
        finish_run(run_id, 0, status="failed")
        raise