Spaces:
Sleeping
Sleeping
| from typing import List, TypedDict | |
| from llm_config import get_llm_instructor, call_llm | |
| from pydantic import BaseModel, Field | |
| import ui | |
| import prompts | |
| from search import fetch_search_results, format_search_results | |
| import random | |
| import time | |
| from dotenv import load_dotenv | |
| import re | |
# Load variables from a local .env file into the process environment at
# import time (python-dotenv).
# NOTE(review): llm_config and search are imported *before* this call; if
# either reads environment variables at import time rather than at call time,
# it will not see values from .env — confirm, or move this above those imports.
load_dotenv()
# Structured reply for one turn of the round-table in
# ImproveContent.warm_start_discussion.
# NOTE: intentionally no class docstring / Field(description=...) here —
# pydantic copies those into the model's JSON schema, which presumably is what
# call_llm hands to the model; adding them would change the prompt. Confirm
# before documenting in-schema.
class RoundtableMessage(BaseModel):
    # What the speaking persona says this turn.
    response: str = Field(..., title="Your response")
    # Question the speaker poses to the next persona.
    follow_up: str = Field(..., title="Your follow-up question")
    # Looked up against the personas' `name` values to choose the next speaker.
    next_persona: str = Field(..., title="Who you are asking the question to")
class ContentState(TypedDict):
    """Mutable state threaded through one persona's interview loop.

    The ``references`` key was previously misspelled ``refernces`` here while
    every reader and writer in this module uses ``references``.
    """

    # Chat history passed back to the LLM as {'role': ..., 'content': ...} dicts.
    previous_messages: List[dict]
    # Text of the most recent answer.
    content: str
    # Question currently being answered.
    expert_question: str
    # Number of completed question/answer rounds.
    iteration: int
    # Every question and answer, in order.
    full_messages: List[str]
    # Accumulated search-result text with renumbered citations.
    references: str
# Structured output of the query-generation step: the search strings that
# ImproveContent.answer_question forwards to fetch_search_results.
# NOTE: no class docstring on purpose — pydantic would copy it into the JSON
# schema (presumably sent to the LLM), changing the prompt.
class Queries(BaseModel):
    # Search strings to run, in the order the model produced them.
    queries : List[str] = Field(..., title="List of queries to search for")
# Structured output of ImproveContent.expert_question_generator: the next
# question a persona asks about the section.
# NOTE: no class docstring on purpose — pydantic would copy it into the JSON
# schema (presumably sent to the LLM), changing the prompt.
class PersonaQuestion(BaseModel):
    # Stored as state['expert_question'] and appended to the message histories.
    question: str = Field(..., title="Your question for the expert")
# Structured output of the RAG answer step in ImproveContent.answer_question.
# The class name's misspelling ("Strucutred") is kept because callers refer to
# it by this name.
# NOTE: no class docstring on purpose — pydantic would copy it into the JSON
# schema (presumably sent to the LLM), changing the prompt.
class StrucutredAnswer(BaseModel):
    # Answer text containing bracketed numeric citations such as [2] or [2,3].
    answer_response: str = Field(..., title="The response to the question with citations")
    # Numbers of the cited search results — presumably the 1-based [n] labels
    # of the search results shown to the model; confirm against the prompt.
    references_used: List[int] = Field(..., title="The references used to answer the question")
class ImproveContent:
    """Grow one report section by interviewing personas backed by web search.

    For every persona, an LLM drafts a question about the section, search
    queries are generated and executed, and a RAG call answers the question
    with numbered citations.  Citation numbers are rewritten onto an
    instance-wide counter (``num_search_result``) so references from different
    rounds and personas never collide and can be merged into one section.
    """

    # Grouped citations in an answer, e.g. "[2, 3,4]".
    _CITATION_GROUP_RE = re.compile(r'\[(\d+(?:,\s*\d+)*)\]')
    # A single numeric citation, e.g. "[7]".
    _CITATION_RE = re.compile(r'\[(\d+)\]')
    # The numbered header of a search result, e.g. "[3] Title: ...".
    _RESULT_HEADER_RE = re.compile(r'\[(\d+)\](?=\s*Title:)')

    def __init__(self, section_topic, section_description, section_key_questions, personas):
        """Store the section brief and interview settings.

        Args:
            section_topic: Topic of the section being improved.
            section_description: Short description of the section.
            section_key_questions: Questions the section should address.
            personas: Objects exposing ``name``, ``role``, ``affiliation``,
                ``focus`` and ``persona`` attributes.
        """
        self.section_topic = section_topic
        self.section_description = section_description
        self.section_key_questions = section_key_questions
        self.client = get_llm_instructor()
        # Next citation number to hand out; shared across rounds and personas.
        self.num_search_result = 1
        # Question/answer rounds per persona.
        self.num_interview_rounds = 3
        self.personas = personas
        # Number of turns in warm_start_discussion.
        self.warm_start_rounds = 10
        # Replaced by create_and_run_interview; starts empty so that
        # generate_final_section() does not raise AttributeError beforehand.
        self.final_state = self.create_initial_state()

    def create_initial_state(self) -> "ContentState":
        """Return a fresh, empty interview state (new lists on every call)."""
        return {
            "expert_question": "",
            "content": "",
            "iteration": 0,
            'previous_messages': [],
            'full_messages': [],
            'references': '',
        }

    def expert_question_generator(self, persona, state: "ContentState") -> "ContentState":
        """Ask the LLM, speaking as *persona*, for the next question.

        The question is shown in the UI, stored in ``state['expert_question']``
        and appended to ``full_messages`` and, as an ``assistant`` turn, to
        ``previous_messages``.  *state* is mutated in place and returned.
        """
        response = call_llm(
            instructions=prompts.QUALITY_CHECKER_INSTRUCTIONS,
            additional_messages=state['previous_messages'],
            context={
                "title_description": self.section_description + ":" + self.section_topic,
                "key_questions": self.section_key_questions,
                'persona': persona.persona,
            },
            response_model=PersonaQuestion,
            logging_fn="quality_checker",
        )
        ui.system_sub_update("-------------------")
        ui.system_sub_update(f'{persona.name} ({persona.role},{persona.affiliation}):')
        ui.system_sub_update(response.question)
        ui.system_sub_update("-------------------")
        state["expert_question"] = response.question
        state['previous_messages'].append({'role': 'assistant', 'content': response.question})
        state['full_messages'].append(response.question)
        return state

    def _assign_numbers(self, refs):
        """Map each distinct local reference (as str) to the next global "[n]".

        Advances ``self.num_search_result`` once per distinct reference.
        """
        mapping = {}
        for ref in refs:
            key = str(ref)
            if key not in mapping:
                mapping[key] = f"[{self.num_search_result}]"
                self.num_search_result += 1
        return mapping

    def _apply_numbers(self, text, mapping):
        """Rewrite every "[n]" in *text* whose n is in *mapping*, in one pass.

        A single pass matters: chained ``str.replace`` calls would rewrite an
        already-renumbered citation again whenever new and old numbers overlap.
        """
        return self._CITATION_RE.sub(lambda m: mapping.get(m.group(1), m.group(0)), text)

    def replace_references(self, text: str, references_list: List[int]) -> str:
        """Renumber the "[idx]" citations in *text* onto the instance counter.

        Each distinct index in *references_list* receives the next value of
        ``self.num_search_result`` (which is advanced); other citations are
        left untouched.
        """
        return self._apply_numbers(text, self._assign_numbers(references_list))

    def _renumber_citations(self, search_results: str, answer: str):
        """Give this round's search results globally unique citation numbers.

        Grouped citations in *answer* ("[2,3]") are first split into "[2][3]".
        Then every result-header number found in *search_results* is mapped to
        the next instance-wide number, and both texts are rewritten with that
        mapping.

        Returns:
            ``(search_results, answer)`` using the new numbering.
        """
        answer = self._CITATION_GROUP_RE.sub(
            lambda m: ''.join(f'[{n.strip()}]' for n in m.group(1).split(',')),
            answer,
        )
        mapping = self._assign_numbers(self._RESULT_HEADER_RE.findall(search_results))
        return self._apply_numbers(search_results, mapping), self._apply_numbers(answer, mapping)

    def _generate_queries(self, persona, state, logging_fn):
        """Ask the fast model for search queries addressing the current question."""
        return call_llm(
            instructions=prompts.IMPROVE_CONTENT_CREATE_QUERY_INSTRUCTIONS,
            model_type='fast',
            context={
                "section_topic": self.section_topic,
                "expert_question": state["expert_question"],
                'persona': persona.persona,
            },
            response_model=Queries,
            logging_fn=logging_fn,
        )

    def answer_question(self, persona, state: "ContentState"):
        """Research and answer ``state['expert_question']`` (a generator).

        Generates search queries and runs them through ``fetch_search_results``.
        If nothing comes back, it retries once with freshly generated queries.
        It then asks the RAG model for a cited answer, renumbers the citations
        onto the instance-wide counter, and records the answer and its sources
        in *state*.  Whatever ``fetch_search_results`` yields is passed through
        to the caller; the generator's return value is the mutated *state*.
        """
        queries = self._generate_queries(persona, state, "improve_content_create_query")
        # Hit the search engine to fetch relevant documents.
        search_results, search_results_list = yield from fetch_search_results(
            queries.queries, self.task_status, self.section_topic, self.update_ui_fn)
        if not search_results_list:
            # Nothing found: request a fresh set of queries and try once more.
            queries = self._generate_queries(persona, state, "improve_content_create_query_fallback")
            search_results, search_results_list = yield from fetch_search_results(
                queries.queries, self.task_status, self.section_topic, self.update_ui_fn)

        response = call_llm(
            instructions=prompts.IMPORVE_CONTENT_ANSWER_QUERY_INSTRUCTION,
            model_type='rag',
            context={
                "section_topic": self.section_topic,
                "expert_question": state["expert_question"],
                "search_results": search_results,
                'persona': persona.persona,
            },
            response_model=StrucutredAnswer,
            logging_fn="improve_content_answer_query",
        )

        search_results, answer = self._renumber_citations(search_results, response.answer_response)
        # Store the renumbered answer so it agrees with state['references'].
        state["content"] = answer

        ui.system_sub_update("-------------------")
        ui.system_sub_update('Content:')
        ui.system_sub_update(answer)
        ui.system_sub_update("-------------------")
        state['previous_messages'].append({'role': 'user', 'content': answer})
        state['full_messages'].append(answer)
        state['references'] = state['references'] + '\n\n' + search_results
        state["iteration"] += 1
        return state

    def create_and_run_interview(self, task_status, update_ui_fn):
        """Interview every persona for ``num_interview_rounds`` rounds (a generator).

        Args:
            task_status: Forwarded unchanged to ``fetch_search_results``.
            update_ui_fn: Forwarded unchanged to ``fetch_search_results``.

        Returns:
            (as the generator's return value) the concatenated
            ``previous_messages`` of all personas.

        ``self.final_state`` is replaced by a state that merges the questions,
        answers and references of *all* personas; this is what
        ``generate_final_section`` reads.
        """
        self.task_status = task_status
        self.update_ui_fn = update_ui_fn
        discussion_messages = []
        combined = self.create_initial_state()
        for persona in self.personas:
            ui.system_update(f"Starting discussion with : {persona.name}: {persona.role}, {persona.affiliation}")
            state = self.create_initial_state()
            # `iteration` counts completed rounds, so `<` runs exactly
            # num_interview_rounds rounds (`<=` ran one extra).
            while state["iteration"] < self.num_interview_rounds:
                state = self.expert_question_generator(persona, state)
                state = yield from self.answer_question(persona, state)
            discussion_messages.extend(state['previous_messages'])
            # Citation numbers are unique across personas, so outputs merge safely.
            combined['previous_messages'].extend(state['previous_messages'])
            combined['full_messages'].extend(state['full_messages'])
            combined['references'] += state['references']
            combined['iteration'] += state['iteration']
            combined['expert_question'] = state['expert_question']
            combined['content'] = state['content']
        self.final_state = combined
        return discussion_messages

    def generate_final_section(self, synopsis=None):
        """Return ``(section_text, references)`` from the last interview run.

        ``section_text`` joins every recorded question and answer with blank
        lines; ``references`` is the accumulated, renumbered search-result
        text.  *synopsis* is accepted for interface compatibility but unused.
        """
        return '\n\n'.join(self.final_state['full_messages']), self.final_state['references']

    def _resolve_next_persona(self, requested_name, current):
        """Return the persona named *requested_name*, tolerating case/whitespace drift.

        If the model names nobody on the panel, fall back to a random persona
        other than *current* (or to *current* when it is the only one) instead
        of raising IndexError.
        """
        for candidate in self.personas:
            if candidate.name == requested_name:
                return candidate
        wanted = str(requested_name or '').strip().casefold()
        for candidate in self.personas:
            if str(candidate.name).strip().casefold() == wanted:
                return candidate
        others = [p for p in self.personas if p != current]
        return random.choice(others or [current])

    def warm_start_discussion(self):
        """Simulate a free-form round-table discussion among the personas.

        Starting from a random persona, each turn asks the fast model for a
        response plus a follow-up question addressed to another persona, who
        speaks next.  Only the five most recent messages are sent as context,
        and the loop sleeps 3 seconds between turns.

        Returns:
            The list of "Name: text" messages, including the opening line.
        """
        messages = [f"{self.personas[0].name}: Hi! Let's get started!"]
        speaker = random.choice(self.personas)
        for _ in range(self.warm_start_rounds):
            message = call_llm(
                instructions=prompts.ROUNDTABLE_DISCUSSION_INSTRUCTIONS,
                model_type='fast',
                context={
                    "persona_name": speaker.name,
                    "persona_role": speaker.role,
                    "persona_affiliation": speaker.affiliation,
                    "persona_focus": speaker.focus,
                    "personas":
                        "\n\n".join([p.name + '\n' + p.persona for p in self.personas if p != speaker]),
                    # Slicing already yields the whole list when it has <= 5 items.
                    "discussion": "\n\n".join(messages[-5:]),
                },
                response_model=RoundtableMessage,
                logging_fn="roundtable_discussion",
            )
            ui.system_sub_update("\n\n" + speaker.name + ": " + message.response + '\n' + message.follow_up)
            messages.append(speaker.name + ": " + message.response + '\n' + message.follow_up)
            speaker = self._resolve_next_persona(message.next_persona, speaker)
            time.sleep(3)
        return messages
if __name__ == "__main__":
    from types import SimpleNamespace

    section_name = 'Glean Search in the Enterprise Search Market'
    section_description = 'Positioning and Competition'
    section_key_questions = ['how is glean positioned in the enterprise search market?', "who are the main competitors in this space?"]
    # ImproveContent reads .name, .role, .affiliation, .focus and .persona from
    # each persona, so a bare description string fails with AttributeError.
    personas = [
        SimpleNamespace(
            name='Business Analyst',
            role='Business Analyst',
            affiliation='Enterprise Software Consultant',
            focus='Glean and Copilot from a business user perspective',
            persona='\nRole: Business Analyst\nAffiliation: Enterprise Software Consultant\nDescription: Specializes in helping organizations implement and optimize AI-powered tools for improved productivity and knowledge management. Will analyze Glean and Copilot from a business user perspective.\n',
        )
    ]
    improve_content = ImproveContent(section_name, section_description, section_key_questions, personas)

    # create_and_run_interview is a generator (it delegates with `yield from`),
    # so it must be driven to completion; its result arrives as StopIteration.value.
    # NOTE(review): the expected types of task_status / update_ui_fn are set by
    # search.fetch_search_results, which is not visible here — None and a no-op
    # callback are placeholders; confirm.
    interview = improve_content.create_and_run_interview(
        task_status=None, update_ui_fn=lambda *args, **kwargs: None)
    while True:
        try:
            next(interview)
        except StopIteration as finished:
            improved_content = finished.value
            break

    # generate_final_section takes a synopsis argument; none is available here.
    final_section, references = improve_content.generate_final_section(None)
    print(improved_content)
    print(final_section)
    print(references)