File size: 1,554 Bytes
0b51cd7
 
 
 
 
 
 
917fc83
9f1f697
 
 
917fc83
 
 
 
0b51cd7
 
 
 
2441327
0b51cd7
 
 
 
2441327
0b51cd7
 
 
 
 
 
 
 
 
 
 
 
 
 
2441327
0b51cd7
 
 
2441327
0b51cd7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import textwrap
from summarise import load_pipe
from scrape     import scrape
from db         import get_conn
from typing import Optional,List
from helpers   import rows_by_tag

# Prompt template for project-idea generation. It is rendered with
# ``str.format(context=...)``, so ``{context}`` is the only placeholder; any
# literal brace added to this text must be doubled (``{{`` / ``}}``).
# Adjacent string literals are concatenated with no separator, so every
# sentence/section boundary needs an explicit space or newline.
IDEA_PROMPT = (
    "You are a senior ML researcher. CONTEXT provides a list of papers. "
    "From this list of papers, propose THREE new research projects.\n"
    "For each research project proposed, give a new Title, one sentence on "
    "Motivation and background, two sentences on the new method, "
    "and one sentence on the Evaluation method.\n"
    "===CONTEXT===\n"
    "{context}\n"
    "===PROJECT IDEAS===\n"
)

# ---------------------------------------------------------------------- #
def ideate_from_topic(topic: str, k: int = 8) -> Optional[str]:
    """Ask the LLM for three research project ideas seeded by a topic tag.

    Fetches paper rows via ``rows_by_tag(topic, k)``, renders each row as a
    ``- <col0>: <col2>`` context line, fills ``IDEA_PROMPT`` with them and
    runs the pipeline returned by ``load_pipe()`` with ``do_sample=False``.

    Args:
        topic: Tag passed straight through to ``rows_by_tag``.
        k: Passed straight through to ``rows_by_tag`` (presumably the maximum
            number of papers to fetch -- confirm in ``helpers``).

    Returns:
        The first result's ``generated_text`` with surrounding whitespace
        stripped, or ``None`` if ``rows_by_tag`` yields no rows.
    """
    papers = rows_by_tag(topic, k)
    if not papers:
        # No matching papers: return before loading the model at all.
        return None

    # Each row is unpacked as a 4-tuple; only columns 0 and 2 are used
    # (presumably title and summary -- confirm against rows_by_tag).
    lines = []
    for title, _unused_a, summary, _unused_b in papers:
        lines.append(f"- {title}: {summary}")
    context = "\n".join(lines)

    generator = load_pipe()
    outputs = generator(IDEA_PROMPT.format(context=context), do_sample=False)
    first = outputs[0]
    return first['generated_text'].strip()

# ------------------------------------------------------------------ #
def ideate_from_ids(ids: List[str]) -> Optional[str]:
    """Ask the LLM for research project ideas seeded by specific paper ids.

    Each id is looked up in the ``papers`` table; ids with no matching row
    are skipped. Every hit contributes one ``- title: summary`` context line
    (in the order of ``ids``) to ``IDEA_PROMPT``, which is then run through
    the pipeline returned by ``load_pipe()`` with ``do_sample=False``.

    Args:
        ids: Paper ids to include. A duplicated id is looked up, and listed
            in the context, once per occurrence.

    Returns:
        The first result's ``generated_text`` with surrounding whitespace
        stripped, or ``None`` when ``ids`` is empty/``None`` or none of the
        ids exists in the database (the model is not loaded in that case).
    """
    # Nothing to look up: avoid opening a DB connection or loading the model.
    # Also makes ``ids=None`` return None instead of raising TypeError.
    if not ids:
        return None

    # ``get_conn`` is already imported at module level; no local re-import.
    # NOTE(review): the connection is not closed here. Whether get_conn()
    # returns a fresh or a shared connection is not visible from this
    # module -- confirm before adding an explicit close().
    conn = get_conn()
    ctx = []
    for pid in ids:
        # ``?`` placeholder keeps the id out of the SQL text.
        row = conn.execute(
            "SELECT title, summary FROM papers WHERE id=?", (pid,)
        ).fetchone()
        if row:
            ctx.append(f"- {row[0]}: {row[1]}")

    if not ctx:
        return None

    llm = load_pipe()
    return llm(IDEA_PROMPT.format(context="\n".join(ctx)),
               do_sample=False)[0]['generated_text'].strip()