# NOTE(review): the original lines here ("Spaces:" / "Sleeping" / "Sleeping")
# are HuggingFace Spaces page chrome captured by a web scrape, not source code.
| import ui | |
| from typing import List | |
| from pydantic import BaseModel, Field | |
| import time | |
| import gradio as gr | |
| from llm_config import call_llm, get_llm_usage | |
| import prompts | |
| from colorama import Fore, Style | |
| # Add these imports at the top | |
| from search import fetch_search_results | |
| from improve_content import ImproveContent | |
| import re | |
| # logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
| # logger = logging.getLogger(__name__) | |
class Section(BaseModel):
    """Outline entry for one report section: its title, scope, guiding
    questions, and (once written) its prose content."""

    name: str = Field(
        description="Name for this section of the report.",
    )
    description: str = Field(
        description="Brief overview of the main topics and concepts to be covered in this section.",
    )
    questions: List[str] = Field(
        description="Key Questions to answer in this section."
    )
    content: str = Field(
        description="The content of the section."
    )
class Sections(BaseModel):
    """Container for the ordered sections of a report outline."""

    sections: List[Section] = Field(
        description="Sections of the report.",
    )

    def as_str(self) -> str:
        """Render the full outline (name, description, questions, content) as markdown.

        BUG FIX: the joined questions are computed outside the f-string.
        The original embedded ``'\\n\\n'.join(...)`` inside an f-string
        expression, which is a SyntaxError on Python < 3.12 (backslash
        escapes were not allowed inside f-string expressions before PEP 701).
        """
        rendered = []
        for section in self.sections or []:
            questions = "\n\n".join(section.questions)
            rendered.append(
                f"## {section.name}\n\n-{section.description}\n\n- Questions: {questions}\n\n- Content: {section.content}\n"
            )
        return "\n\n".join(rendered)

    def print_sections(self) -> str:
        """Concatenate the written content of every section.

        Uses the same ``or []`` guard as :meth:`as_str` for consistency,
        so a ``None`` sections list renders as an empty string instead of raising.
        """
        return '\n\n'.join(s.content for s in self.sections or [])
class ResearchArea(BaseModel):
    """One research area together with the query used to investigate it."""

    area: str = Field(..., title="Research Area")
    search_terms: str = Field(..., title="Search Term", description="Search query that will help you find information")


class ResearchFocus(BaseModel):
    """Collection of research areas to pursue."""

    areas: List[ResearchArea] = Field(..., title="Research Areas")


class RelevantSearchResults(BaseModel):
    """LLM selection of relevant search results, with reasoning per pick."""

    relevant_search_results: List[int] = Field(..., title="Relevant Search Results", description="The position of the search result in the search results list")
    reasoning: List[str] = Field(..., title="Reasoning", description="Reasoning for selecting the search results")


class SearchTerm(BaseModel):
    """A single search query."""

    query: str = Field(..., title="Search Query")
    # time_range : str = Field(..., title="Time Range", description="d/w/m/y/none")


class SearchTermsList(BaseModel):
    """A batch of search queries returned by the LLM."""

    queries: List[str] = Field(..., title="Search Terms as a list")
class Editor(BaseModel):
    """Persona participating in the mock roundtable discussions."""

    name: str = Field(
        description="Name of the editor.",
    )
    affiliation: str = Field(
        description="Primary affiliation of the editor.",
    )
    role: str = Field(
        description="Role of the editor in the context of the topic.",
    )
    focus: str = Field(
        description="Description of the editor's focus area, concerns and how they will help.",
    )

    def persona(self) -> str:
        """Render this editor as a prompt-ready persona description."""
        return f"\nRole: {self.role}\nAffiliation: {self.affiliation}\nDescription: {self.focus}\n"


class Perspectives(BaseModel):
    """A set of editor personas."""

    editors: List[Editor] = Field(
        description="Comprehensive list of editors with their roles and affiliations.",
    )
class ReportSynopsis(BaseModel):
    """Short synopsis of the planned report."""

    synopsis: str = Field(..., title="Report Synopsis", description="A synopsis talking about what the reader can expect")


class SectionContent(BaseModel):
    """Drafted prose for a single report section."""

    content: str = Field(..., title="Section Content", description="The content of the section")
class ResearchManager:
    """Manages the research process including analysis, search, and documentation"""

    def __init__(self, research_task):
        # Raw task definition; later code reads 'topic', 'report_type',
        # 'key_questions' and 'section_length' from it.
        self.research_task = research_task
        self.use_existing_outline = True
        self.report_synopsis = ''
        self.personas = ''
        self.gradio_report_outline = ''
        # Progress tracker rendered by update_ui(); each entry's status
        # moves through pending -> running -> done.
        self.task_status = {
            'synopsis_draft': {"name": "Creating synopsis of the report...", "status": "pending"},
            'gathering_info': {"name": "Gathering Info on the topic...", "status": "pending"},
            'running_searches': {"name": "Run search...", "status": "pending"},
            'mock_discussion': {"name": "Conducting mock discussions...", "status": "pending"},
            'generating_outline': {"name": "Generating a draft outline...", "status": "pending"},
        }
| def extract_citation_info(self,text): | |
| """ | |
| Extract citation number and URL from citation text | |
| """ | |
| references = {} | |
| for ref in text: | |
| # Find citation number | |
| citation_match = re.search(r'\[(\d+)\]', ref) | |
| citation_number = citation_match.group(1) if citation_match else None | |
| # Find URL | |
| url_match = re.search(r'URL: (https?://\S+)', ref) | |
| url = url_match.group(1) if url_match else None | |
| references[citation_number] = { | |
| 'url': url, | |
| 'reference_text': ref | |
| } | |
| return references | |
| def section_writer(self, section: Section): | |
| """Given an outline of a section, generate search queries, | |
| perform searches and generate the section content""" | |
| improve_content = ImproveContent(section.name, | |
| section.description, | |
| section.questions, | |
| self.personas.editors | |
| ) | |
| improved_content = yield from improve_content.create_and_run_interview(self.task_status, self.update_gradio) | |
| content, references = improve_content.generate_final_section(self.report_synopsis) | |
| self.task_status[section.name]["name"] = "Writing Section: " + section.name | |
| yield from self.update_gradio() | |
| ui.system_update(f"Writing Section: {section.name}") | |
| section_content = call_llm( | |
| instructions=prompts.WRITE_SECTION_INSTRUCTIONS, | |
| model_type='slow', | |
| context={"section_description": section.description, | |
| "gathered_info" : '\n\n'.join(content), | |
| "topic": self.research_task['topic'], | |
| "section_title" : section.name, | |
| "synopsis" : self.report_synopsis, | |
| "section_questions" : '\n'.join(section.questions), | |
| 'report_type': self.research_task['report_type'], | |
| 'section_length': self.research_task['section_length']}, | |
| response_model=SectionContent, | |
| logging_fn='write_section_instructions' | |
| ) | |
| #references = '\n\n'.join(references) | |
| references_dict = self.extract_citation_info(references.split('\n\n')) | |
| #Replacing citations with [2,3,4] format with [2][3][4] | |
| cited_references_raw = re.findall(r'\[(\d+(?:,\s*\d+)*)\]', section_content.content) | |
| for group in cited_references_raw: | |
| nums_list = group.split(',') | |
| new_string = ''.join(f'[{n.strip()}]' for n in nums_list) | |
| old_string = f'[{group}]' | |
| section_content.content = section_content.content.replace(old_string, new_string) | |
| parsed_cited_references = [] | |
| for ref_group in cited_references_raw: | |
| for ref_no in ref_group.split(','): | |
| parsed_cited_references.append(ref_no.strip()) | |
| used_references = {} | |
| uncited_sources= [] | |
| for reference_no in parsed_cited_references: | |
| reference = references_dict.get(reference_no) | |
| if reference: | |
| used_references[reference_no] = reference | |
| else: | |
| print(f"Reference {reference_no} not found") | |
| uncited_sources.append(reference_no) | |
| section_content.content = section_content.content.replace(f"[{reference_no}]", "[!]") | |
| for ref_no, data in used_references.items(): | |
| if data["url"]: | |
| section_content.content = section_content.content.replace(f"[{ref_no}]", f"[[{ref_no}]]({data['url']})") | |
| section.content = section_content.content | |
| print(section_content.content) | |
| self.task_status[section.name]["status"] = "done" | |
| yield from self.update_gradio(report_outline_str=self.report_outline.print_sections(), button_disable=False) | |
| ui.system_update("Waiting for 5 seconds before next section") | |
| time.sleep(5) | |
| return section | |
| def _generate_report_outline(self): | |
| """Use LLM to generate focus areas for research based on the original query""" | |
| ui.system_update(f"\nGathering Context..") | |
| self.task_status['gathering_info']["status"] = "running" | |
| yield from self.update_gradio() | |
| queries = call_llm( | |
| instructions=prompts.FIND_SEARCH_TERMS_INSTRUCTIONS, | |
| model_type='fast', | |
| context={ | |
| "report_type": self.research_task['report_type'], | |
| "original_query": self.research_task['topic'], | |
| "report_synopsis": self.report_synopsis, | |
| }, | |
| response_model=SearchTermsList, | |
| logging_fn='find_search_terms_instructions' | |
| ) | |
| self.task_status['running_searches']["status"] = "running" | |
| yield from self.update_gradio() | |
| formatted_results, results = yield from fetch_search_results(query=queries.queries, | |
| task_status=self.task_status, | |
| task_name = 'running_searches', | |
| fn = self.update_gradio) | |
| self.context = formatted_results | |
| self.task_status['running_searches']["status"] = "done" | |
| self.task_status['gathering_info']["status"] = "done" | |
| self.task_status['mock_discussion']["status"] = "running" | |
| yield from self.update_gradio() | |
| personas = call_llm( | |
| instructions=prompts.GENERATE_ROUNDTABLE_PERSONAS_INSTRUCTIONS, | |
| model_type='slow', | |
| context={"context": self.context, | |
| "topic": self.research_task['topic'], | |
| "report_synopsis": self.report_synopsis, | |
| 'type_of_report': self.research_task['report_type'], | |
| 'num_personas': 5}, | |
| response_model=Perspectives, | |
| logging_fn='generate_roundtable_personas_instructions' | |
| ) | |
| self.task_status['mock_discussion']["name"] = "Started discussions..." | |
| print(personas) | |
| yield from self.update_gradio() | |
| improve_content = ImproveContent(self.research_task['topic'], | |
| "This section will focus on a comprehensive overview of glean", | |
| self.research_task['key_questions'], | |
| personas.editors) | |
| warm_start_discussion = improve_content.warm_start_discussion() | |
| self.task_status['mock_discussion']["name"] = "Mock discussions complete" | |
| self.task_status['mock_discussion']["status"] = "done" | |
| self.task_status['generating_outline']["status"] = "running" | |
| yield from self.update_gradio() | |
| ui.system_update("\nGenerating Report Outline..") | |
| report_outline = call_llm( | |
| instructions=prompts.GENERATE_REPORT_OUTLINE_INSTRUCTIONS, | |
| model_type='slow', | |
| context={ | |
| "report_type": self.research_task['report_type'], | |
| "topic": self.research_task['topic'], | |
| "context": self.context, | |
| "discussion": '\n'.join(warm_start_discussion), | |
| 'num_sections': 3 | |
| }, | |
| response_model=Sections, | |
| logging_fn='generate_report_outline_instructions' | |
| ) | |
| self.task_status['generating_outline']["status"] = "done" | |
| yield from self.update_gradio(report_outline_str=report_outline.as_str) | |
| print(report_outline.as_str) | |
| return report_outline | |
| def validate_outline_with_human(self, report_outline: Sections) -> Sections: | |
| """Ask the human feedback and improve the report outline till they say 'OK' """ | |
| while True: | |
| ui.system_update("\nPlease provide feedback on the generated report outline") | |
| feedback = ui.get_multiline_input() | |
| if feedback.lower() == 'ok': | |
| return report_outline | |
| ui.system_update("\nImproving the report outline based on your feedback") | |
| extract_sections_chain = prompts.IMPROVE_REPORT_OUTLINE_PROMPT | self.llm.with_structured_output(Sections) | |
| report_outline = extract_sections_chain.invoke({"topic": self.research_task['topic'], "feedback": feedback, "report_outline": report_outline.as_str}) | |
| ui.system_output(report_outline.as_str) | |
| def create_report_synopsis(self): | |
| return call_llm( | |
| instructions=prompts.CREATE_SYNOPSIS_INSTRUCTIONS, | |
| model_type='fast', | |
| context={ | |
| "report_type": self.research_task['report_type'], | |
| "topic": self.research_task['topic'], | |
| "key_questions": self.research_task['key_questions'], | |
| }, | |
| response_model=ReportSynopsis, | |
| logging_fn='create_synopsis_instructions' | |
| ) | |
| def update_gradio(self, report_outline_str = '', button_disable = False): | |
| if report_outline_str != '': | |
| self.gradio_report_outline = report_outline_str | |
| yield [gr.update(interactive=button_disable), self.update_ui(), self.gradio_report_outline] | |
| def start_research(self): | |
| """Main research loop with comprehensive functionality""" | |
| self.task_status['synopsis_draft']["status"] = "running" | |
| yield from self.update_gradio() | |
| ui.system_update(f"Starting research on: {self.research_task['topic']}") | |
| ui.system_update("\nGenerating report outline") | |
| self.report_synopsis = self.create_report_synopsis() | |
| self.task_status['synopsis_draft']["status"] = "done" | |
| yield from self.update_gradio() | |
| self.report_outline = yield from self._generate_report_outline() | |
| #self.report_outline = self.validate_outline_with_human(self.report_outline) | |
| for section in self.report_outline.sections: | |
| self.task_status[section.name] = {"name": f"Starting Section: {section.name}", "status": "pending"} | |
| yield from self.update_gradio() | |
| ui.system_update("\nGenerating personas for writing sections") | |
| self.personas = call_llm( | |
| instructions=prompts.GENERATE_PERSONAS_INSTRUCTIONS, | |
| model_type='slow', | |
| context={ | |
| "topic": self.research_task['topic'], | |
| "report_synopsis": self.report_synopsis, | |
| 'type_of_report': self.research_task['report_type'], | |
| 'num_personas': 2}, | |
| response_model=Perspectives, | |
| logging_fn='generate_personas_instructions' | |
| ) | |
| ui.system_update("\nWriting Sections....") | |
| for section in self.report_outline.sections: | |
| self.task_status[section.name]["status"] = "running" | |
| yield from self.update_gradio() | |
| ui.system_sub_update(f"\nWriting Section: {section.name}") | |
| section = yield from self.section_writer(section) | |
| for section in self.report_outline.sections: | |
| print(section.content) | |
| def update_ui(self): | |
| completed_tasks = sum(1 for _, task in self.task_status.items() if task["status"] == "done") | |
| total_tasks = len(self.task_status) | |
| progress_percentage = int((completed_tasks / total_tasks) * 100) | |
| html_output = f""" | |
| <style> | |
| .progress-bar-container {{ | |
| width: 100%; | |
| background-color: #f3f3f3; | |
| border-radius: 5px; | |
| overflow: hidden; | |
| margin-bottom: 20px; | |
| }} | |
| .progress-bar {{ | |
| height: 20px; | |
| width: {progress_percentage}%; | |
| background-color: #3498db; | |
| transition: width 0.3s; | |
| display: flex; | |
| align-items: center; | |
| justify-content: center; | |
| color: white; | |
| font-weight: bold; | |
| font-size: 12px; | |
| }} | |
| .progress-task {{ | |
| display: flex; | |
| align-items: center; | |
| gap: 10px; | |
| font-family: 'Helvetica Neue', Arial, sans-serif; | |
| margin: 5px 0; | |
| font-size: 14px; | |
| font-weight: 500; | |
| color: #333; | |
| }} | |
| .progress-task .task-name {{ | |
| flex-grow: 1; | |
| }} | |
| .progress-task .icon {{ | |
| width: 20px; | |
| height: 20px; | |
| }} | |
| .loading-circle {{ | |
| width: 15px; | |
| height: 15px; | |
| border: 3px solid #ccc; | |
| border-top: 3px solid #3498db; | |
| border-radius: 50%; | |
| animation: spin 1s linear infinite; | |
| }} | |
| @keyframes spin {{ | |
| 0% {{ transform: rotate(0deg); }} | |
| 100% {{ transform: rotate(360deg); }} | |
| }} | |
| .done-icon {{ | |
| color: #2ecc71; | |
| font-size: 16px; | |
| }} | |
| .checkbox {{ | |
| width: 15px; | |
| height: 15px; | |
| border: 1px solid #ccc; | |
| display: inline-block; | |
| margin-right: 10px; | |
| }} | |
| .milestone {{ | |
| display: inline-block; | |
| width: 10px; | |
| height: 10px; | |
| background-color: #ccc; | |
| border-radius: 50%; | |
| margin: 0 5px; | |
| }} | |
| .milestone.completed {{ | |
| background-color: #2ecc71; | |
| }} | |
| </style> | |
| <div class='progress-bar-container'> | |
| <div class='progress-bar'>{progress_percentage}%</div> | |
| </div> | |
| <div style='display: flex; justify-content: center; margin-bottom: 20px;'> | |
| {''.join([f"<div class='milestone {'completed' if i < completed_tasks else ''}'></div>" for i in range(total_tasks)])} | |
| </div> | |
| """ | |
| for _, task in self.task_status.items(): | |
| if task["status"] == "running": | |
| icon = "<div class='loading-circle'></div>" | |
| elif task["status"] == "done": | |
| icon = "<span class='done-icon'>✓</span>" | |
| else: | |
| icon = "<div class='checkbox'></div>" | |
| html_output += f"<div class='progress-task'><span class='icon'>{icon}</span><span class='task-name'>{task['name']}</span></div>" | |
| return html_output | |