File size: 810 Bytes
8496edd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from .base_agent import BaseAgent
from prompt.template import CREATE_CHART_PROMPT

class Chart(BaseAgent):
    """Agent that asks its LLM to generate charts for a paper.

    Every chart is produced by filling ``CREATE_CHART_PROMPT`` with the paper
    content, the charts generated so far, and an optional user prompt, then
    passing the resulting text to ``self.llm.generate``.
    """

    # Separator placed between previously generated charts when they are fed
    # back into the prompt as ``existing_charts``. Subclasses may override it.
    CHART_SEPARATOR = '\n---\n'

    def __init__(self, llm):
        # NOTE(review): the methods below read ``self.llm``; this assumes
        # BaseAgent.__init__ stores the argument under that name — confirm.
        super().__init__(llm)

    def create_single_chart(self, paper_content: str, existing_charts: str, user_prompt: str = '') -> str:
        """Generate a single chart with one LLM call.

        Args:
            paper_content: Text of the paper, substituted into the prompt.
            existing_charts: Previously generated charts, substituted into the
                prompt (empty string when none exist yet).
            user_prompt: Optional extra instructions, substituted into the
                prompt; defaults to an empty string.

        Returns:
            Whatever ``self.llm.generate`` returns for the formatted prompt.
            ``create_charts`` joins these values as strings, so a ``str`` is
            expected.
        """
        prompt = CREATE_CHART_PROMPT.format(
            paper_content=paper_content,
            existing_charts=existing_charts,
            user_prompt=user_prompt,
        )
        return self.llm.generate(prompt)

    def create_charts(self, paper_content: str, chart_num: int, user_prompt: str = '') -> list:
        """Generate ``chart_num`` charts one after another.

        Each call receives every chart generated earlier in this run, joined
        with ``CHART_SEPARATOR``, as its ``existing_charts`` argument. The
        calls are therefore sequential and each depends on the previous ones.

        Args:
            paper_content: Text of the paper passed to every generation call.
            chart_num: Number of charts to generate. Zero or a negative value
                produces no LLM calls.
            user_prompt: Optional extra instructions passed to every call.

        Returns:
            A list of the generated charts in generation order (empty when
            ``chart_num <= 0``).
        """
        charts = []
        existing_charts = ''
        for _ in range(chart_num):
            chart = self.create_single_chart(paper_content, existing_charts, user_prompt)
            charts.append(chart)
            # Feed all charts so far back into the next prompt (presumably so
            # the model can avoid repeating itself).
            existing_charts = self.CHART_SEPARATOR.join(charts)
        return charts