import csv
import io
import json
import logging
import os
import queue
import urllib.parse
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from functools import partial
from queue import Queue
from typing import Any, Dict, Generator, List, Optional, Tuple, TypeVar

import click
import duckdb
import gradio as gr
import requests
from bs4 import BeautifulSoup
from dotenv import load_dotenv
from jinja2 import BaseLoader, Environment
from openai import OpenAI
from pydantic import BaseModel, create_model

TypeVar_BaseModel = TypeVar("TypeVar_BaseModel", bound=BaseModel)

script_dir = os.path.dirname(os.path.abspath(__file__))
default_env_file = os.path.abspath(os.path.join(script_dir, ".env"))
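
# Credentials are loaded from a .env file next to this script (see
# read_env_variables below). A minimal example, with placeholder values:
#
#   SEARCH_API_KEY=<google-custom-search-api-key>
#   SEARCH_PROJECT_KEY=<google-custom-search-cx-id>
#   LLM_API_KEY=<openai-compatible-api-key>
#   # optional overrides:
#   # LLM_BASE_URL=https://api.openai.com/v1
#   # EMBEDDING_MODEL=text-embedding-3-small
#   # EMBEDDING_DIMENSIONS=1536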


class OutputMode(str, Enum):
    answer = "answer"
    extract = "extract"


class AskSettings(BaseModel):
    date_restrict: int
    target_site: str
    output_language: str
    output_length: int
    url_list: List[str]
    inference_model_name: str
    hybrid_search: bool
    output_mode: OutputMode
    extract_schema_str: str
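
# An illustrative settings object (values here are examples only; the CLI at
# the bottom of this file builds the real one from its options):
#
#   settings = AskSettings(
#       date_restrict=0,
#       target_site="",
#       output_language="English",
#       output_length=0,
#       url_list=[],
#       inference_model_name="gpt-4o-mini",
#       hybrid_search=False,
#       output_mode=OutputMode.answer,
#       extract_schema_str="",
#   )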


def _get_logger(log_level: str) -> logging.Logger:
    logger = logging.getLogger(__name__)
    logger.setLevel(log_level)
    if len(logger.handlers) > 0:
        return logger
    handler = logging.StreamHandler()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger


def _read_url_list(url_list_file: str) -> List[str]:
    if not url_list_file:
        return []
    with open(url_list_file, "r") as f:
        links = f.readlines()
    url_list = [
        link.strip()
        for link in links
        if link.strip() != "" and not link.startswith("#")
    ]
    return url_list


def _read_extract_schema_str(extract_schema_file: str) -> str:
    if not extract_schema_file:
        return ""
    with open(extract_schema_file, "r") as f:
        schema_str = f.read()
    return schema_str
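
# The extract schema file is plain Python text defining a single Pydantic
# model; _get_target_class below exec()s it with BaseModel in scope. A
# hypothetical example file:
#
#   class Product(BaseModel):
#       name: str
#       price: float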


def _output_csv(result_dict: Dict[str, List[BaseModel]], key_name: str) -> str:
    # generate the CSV content from a Dict of URL and list of extracted items
    output = io.StringIO()
    csv_writer = None
    for src_url, items in result_dict.items():
        for item in items:
            value_dict = item.model_dump()
            item_with_url = {**value_dict, key_name: src_url}
            if csv_writer is None:
                headers = list(value_dict.keys()) + [key_name]
                csv_writer = csv.DictWriter(output, fieldnames=headers)
                csv_writer.writeheader()
            csv_writer.writerow(item_with_url)
    csv_content = output.getvalue()
    output.close()
    return csv_content
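
# With the hypothetical Product schema above and key_name="SourceURL" (the
# value used by run_query_gradio), the CSV would look like:
#
#   name,price,SourceURL
#   Widget,9.99,https://example.com/catalog
#   Gadget,19.99,https://example.com/catalog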


class Ask:

    def __init__(self, logger: Optional[logging.Logger] = None):
        self.read_env_variables()

        if logger is not None:
            self.logger = logger
        else:
            self.logger = _get_logger("INFO")

        self.db_con = duckdb.connect(":memory:")
        self.db_con.install_extension("vss")
        self.db_con.load_extension("vss")
        self.db_con.install_extension("fts")
        self.db_con.load_extension("fts")
        self.db_con.sql("CREATE SEQUENCE seq_docid START 1000")

        self.session = requests.Session()
        user_agent: str = (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
        )
        self.session.headers.update({"User-Agent": user_agent})

    def read_env_variables(self) -> None:
        err_msg = ""

        self.search_api_key = os.environ.get("SEARCH_API_KEY")
        if self.search_api_key is None:
            err_msg += "SEARCH_API_KEY env variable not set.\n"

        self.search_project_id = os.environ.get("SEARCH_PROJECT_KEY")
        if self.search_project_id is None:
            err_msg += "SEARCH_PROJECT_KEY env variable not set.\n"

        self.llm_api_key = os.environ.get("LLM_API_KEY")
        if self.llm_api_key is None:
            err_msg += "LLM_API_KEY env variable not set.\n"

        if err_msg != "":
            raise Exception(f"\n{err_msg}\n")

        self.llm_base_url = os.environ.get("LLM_BASE_URL")
        if self.llm_base_url is None:
            self.llm_base_url = "https://api.openai.com/v1"

        self.embedding_model = os.environ.get("EMBEDDING_MODEL")
        self.embedding_dimensions = os.environ.get("EMBEDDING_DIMENSIONS")
        if self.embedding_model is None or self.embedding_dimensions is None:
            self.embedding_model = "text-embedding-3-small"
            self.embedding_dimensions = 1536
        else:
            # env values come in as strings; the dimension is used as an integer
            self.embedding_dimensions = int(self.embedding_dimensions)

    def search_web(self, query: str, settings: AskSettings) -> List[str]:
        escaped_query = urllib.parse.quote(query)
        url_base = (
            f"https://www.googleapis.com/customsearch/v1?key={self.search_api_key}"
            f"&cx={self.search_project_id}&q={escaped_query}"
        )
        url_paras = "&safe=active"
        if settings.date_restrict > 0:
            url_paras += f"&dateRestrict={settings.date_restrict}"
        if settings.target_site:
            url_paras += f"&siteSearch={settings.target_site}&siteSearchFilter=i"
        url = f"{url_base}{url_paras}"
        self.logger.debug(f"Searching for query: {query}")

        resp = requests.get(url)
        if resp.status_code != 200:
            raise Exception(
                f"Search API request failed with status {resp.status_code}: {resp.text}"
            )

        search_results_dict = json.loads(resp.text)
        if "error" in search_results_dict:
            raise Exception(
                f"Error in search API response: {search_results_dict['error']}"
            )

        if "searchInformation" not in search_results_dict:
            raise Exception(
                f"No search information in search API response: {resp.text}"
            )

        # the API reports totalResults as a string
        total_results = int(
            search_results_dict["searchInformation"].get("totalResults", 0)
        )
        if total_results == 0:
            self.logger.warning(f"No results found for query: {query}")
            return []

        results = search_results_dict.get("items", [])
        if results is None or len(results) == 0:
            self.logger.warning(f"No result items in the response for query: {query}")
            return []

        found_links = []
        for result in results:
            link = result.get("link", None)
            if link is None or link == "":
                self.logger.warning(f"Search result link missing: {result}")
                continue
            found_links.append(link)
        return found_links
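
    # A query like "llm agents" with date_restrict=7 and target_site set to
    # "example.com" produces a request of the form (key and cx elided):
    #
    #   https://www.googleapis.com/customsearch/v1?key=...&cx=...
    #       &q=llm%20agents&safe=active&dateRestrict=7
    #       &siteSearch=example.com&siteSearchFilter=i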

    def _scrape_url(self, url: str) -> Tuple[str, str]:
        self.logger.info(f"Scraping {url} ...")
        try:
            response = self.session.get(url, timeout=10)
            soup = BeautifulSoup(response.content, "lxml", from_encoding="utf-8")

            body_tag = soup.body
            if body_tag:
                body_text = body_tag.get_text()
                body_text = " ".join(body_text.split()).strip()
                self.logger.debug(f"Scraped {url}: {body_text[:100]}...")
                if len(body_text) > 100:
                    self.logger.info(
                        f"✅ Successfully scraped {url} with length: {len(body_text)}"
                    )
                    return url, body_text
                else:
                    self.logger.warning(
                        f"Body text too short for url: {url}, length: {len(body_text)}"
                    )
                    return url, ""
            else:
                self.logger.warning(f"No body tag found in the response for url: {url}")
                return url, ""
        except Exception as e:
            self.logger.error(f"Scraping error {url}: {e}")
            return url, ""

    def scrape_urls(self, urls: List[str]) -> Dict[str, str]:
        # the key is the url and the value is the body text
        scrape_results: Dict[str, str] = {}

        with ThreadPoolExecutor(max_workers=10) as executor:
            results = executor.map(self._scrape_url, urls)
            for url, body_text in results:
                if body_text != "":
                    scrape_results[url] = body_text
        return scrape_results

    def chunk_results(
        self, scrape_results: Dict[str, str], size: int, overlap: int
    ) -> Dict[str, List[str]]:
        chunking_results: Dict[str, List[str]] = {}
        for url, text in scrape_results.items():
            chunks = []
            for pos in range(0, len(text), size - overlap):
                chunks.append(text[pos : pos + size])
            chunking_results[url] = chunks
        return chunking_results
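
    # With size=1000 and overlap=100 (the values used in run_query_gradio),
    # chunks start at offsets 0, 900, 1800, ..., so each chunk is up to 1000
    # characters and consecutive chunks share 100 characters of context.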

    def get_embedding(self, client: OpenAI, texts: List[str]) -> List[List[float]]:
        if len(texts) == 0:
            return []

        response = client.embeddings.create(input=texts, model=self.embedding_model)
        embeddings = [item.embedding for item in response.data]
        return embeddings

    def batch_get_embedding(
        self, client: OpenAI, chunk_batch: Tuple[str, List[str]]
    ) -> Tuple[Tuple[str, List[str]], List[List[float]]]:
        """
        Return the chunk_batch as well as the embeddings for each chunk so that
        we can aggregate them and save them to the database together.

        Args:
        - client: OpenAI client
        - chunk_batch: Tuple of URL and list of chunks scraped from the URL

        Returns:
        - Tuple of chunk_batch and the list of result embeddings
        """
        texts = chunk_batch[1]
        embeddings = self.get_embedding(client, texts)
        return chunk_batch, embeddings

    def _create_table(self) -> str:
        # a simple way to get a unique table name per run
        timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
        table_name = f"document_chunks_{timestamp}"

        self.db_con.execute(
            f"""
            CREATE TABLE {table_name} (
                doc_id INTEGER PRIMARY KEY DEFAULT nextval('seq_docid'),
                url TEXT,
                chunk TEXT,
                vec FLOAT[{self.embedding_dimensions}]
            );
            """
        )
        return table_name
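
    # Note: DuckDB's vss extension builds HNSW indexes over fixed-size array
    # columns, which is why vec is declared as FLOAT[n] with the embedding
    # dimension baked into the type.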

    def save_chunks_to_db(self, chunking_results: Dict[str, List[str]]) -> str:
        """
        The key of chunking_results is the URL and the value is the list of chunks.
        """
        client = self._get_api_client()
        embed_batch_size = 50
        query_batch_size = 100
        insert_data = []

        table_name = self._create_table()

        batches: List[Tuple[str, List[str]]] = []
        for url, list_chunks in chunking_results.items():
            for i in range(0, len(list_chunks), embed_batch_size):
                # use a fresh name here; reassigning list_chunks would corrupt
                # the slicing on subsequent iterations
                batch_chunks = list_chunks[i : i + embed_batch_size]
                batches.append((url, batch_chunks))

        self.logger.info(f"Embedding {len(batches)} batches of chunks ...")
        partial_get_embedding = partial(self.batch_get_embedding, client)
        with ThreadPoolExecutor(max_workers=10) as executor:
            all_embeddings = executor.map(partial_get_embedding, batches)
        self.logger.info("✅ Finished embedding.")

        # We batch the insert data to speed up the insertion operation.
        # Although the DuckDB doc says executemany is optimized for batch
        # inserts, we found it faster to build a single multi-row INSERT.
        for chunk_batch, embeddings in all_embeddings:
            url = chunk_batch[0]
            list_chunks = chunk_batch[1]
            insert_data.extend(
                [
                    # crude escaping: single quotes would break the SQL literal
                    (url.replace("'", " "), chunk.replace("'", " "), embedding)
                    for chunk, embedding in zip(list_chunks, embeddings)
                ]
            )

        for i in range(0, len(insert_data), query_batch_size):
            value_str = ", ".join(
                [
                    f"('{url}', '{chunk}', {embedding})"
                    for url, chunk, embedding in insert_data[i : i + query_batch_size]
                ]
            )
            query = f"""
                INSERT INTO {table_name} (url, chunk, vec) VALUES {value_str};
            """
            self.db_con.execute(query)

        self.db_con.execute(
            f"""
            CREATE INDEX {table_name}_cos_idx ON {table_name} USING HNSW (vec)
            WITH (metric = 'cosine');
            """
        )
        self.logger.info("✅ Created the vector index ...")
        self.db_con.execute(
            f"""
            PRAGMA create_fts_index(
                {table_name}, 'doc_id', 'chunk'
            );
            """
        )
        self.logger.info("✅ Created the full text search index ...")
        return table_name
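
    # create_fts_index builds a BM25 index in a generated schema named
    # fts_main_<table>, which is where the match_bm25 function used by
    # vector_search below comes from.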

    def vector_search(
        self, table_name: str, query: str, settings: AskSettings
    ) -> List[Dict[str, Any]]:
        """
        The return value is a list of {url: str, chunk: str} records.
        In a real-world system, we would define a Chunk class to carry more
        metadata such as offsets.
        """
        client = self._get_api_client()
        query_embedding = self.get_embedding(client, [query])[0]
        query_result: duckdb.DuckDBPyRelation = self.db_con.sql(
            f"""
            SELECT * FROM {table_name}
            ORDER BY array_distance(
                vec, {query_embedding}::FLOAT[{self.embedding_dimensions}]
            )
            LIMIT 10;
            """
        )
        self.logger.debug(query_result)

        # use a dict keyed by doc_id to remove duplicates between the vector
        # search and the full-text search
        matched_chunks_dict = {}
        for vec_result in query_result.fetchall():
            doc_id = vec_result[0]
            result_record = {
                "url": vec_result[1],
                "chunk": vec_result[2],
            }
            matched_chunks_dict[doc_id] = result_record

        if settings.hybrid_search:
            self.logger.info("Running full-text search ...")
            self.db_con.execute(
                f"""
                PREPARE fts_query AS (
                    WITH scored_docs AS (
                        SELECT *, fts_main_{table_name}.match_bm25(
                            doc_id, ?, fields := 'chunk'
                        ) AS score FROM {table_name})
                    SELECT doc_id, url, chunk, score
                    FROM scored_docs
                    WHERE score IS NOT NULL
                    ORDER BY score DESC
                    LIMIT 10)
                """
            )
            self.db_con.execute("PRAGMA threads=4")

            # more complex query rewrites could happen here,
            # e.g. stemming and stop-word removal
            escaped_query = query.replace("'", " ")
            fts_result: duckdb.DuckDBPyRelation = self.db_con.execute(
                f"EXECUTE fts_query('{escaped_query}')"
            )
            index = 0
            for fts_record in fts_result.fetchall():
                index += 1
                self.logger.debug(f"The full text search record #{index}: {fts_record}")
                doc_id = fts_record[0]
                result_record = {
                    "url": fts_record[1],
                    "chunk": fts_record[2],
                }

                # the score threshold and top-k are configurable; results are
                # sorted by score, so we can stop at the first low score
                if fts_record[3] > 1:
                    matched_chunks_dict[doc_id] = result_record
                else:
                    break

                if index >= 10:
                    break

        # the declared return type is a list, not a dict view
        return list(matched_chunks_dict.values())

    def _get_api_client(self) -> OpenAI:
        return OpenAI(api_key=self.llm_api_key, base_url=self.llm_base_url)

    def _render_template(self, template_str: str, variables: Dict[str, Any]) -> str:
        env = Environment(loader=BaseLoader(), autoescape=False)
        template = env.from_string(template_str)
        return template.render(variables)

    def _get_target_class(self, extract_schema_str: str) -> TypeVar_BaseModel:
        # run the schema text as Python with BaseModel in scope and return the
        # first Pydantic model it defines
        local_namespace = {"BaseModel": BaseModel}
        exec(extract_schema_str, local_namespace, local_namespace)
        for key, value in local_namespace.items():
            if key == "__builtins__":
                continue
            if key == "BaseModel":
                continue
            if isinstance(value, type) and issubclass(value, BaseModel):
                return value
        raise Exception("No Pydantic schema found in the extract schema str.")
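
    # Note: exec() will run any Python contained in the schema file, so only
    # load schema files from trusted sources.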

    def run_inference(
        self,
        query: str,
        matched_chunks: List[Dict[str, Any]],
        settings: AskSettings,
    ) -> str:
        system_prompt = (
            "You are an expert summarizing the answers based on the provided contents."
        )
        user_prompt_template = """
Given the context as a sequence of references with a reference id in the
format of a leading [x], please answer the following question using {{ language }}:

{{ query }}

In the answer, use the format [1], [2], ..., [n] inline where the reference is used.
For example, "According to the research from Google[3], ...".

Please create the answer strictly based on the context. If the context has no
information about the query, please write "No related information found in the context."
using {{ language }}.

{{ length_instructions }}

Here is the context:
{{ context }}
"""
        context = ""
        for i, chunk in enumerate(matched_chunks):
            context += f"[{i+1}] {chunk['chunk']}\n"

        if not settings.output_length:
            length_instructions = ""
        else:
            length_instructions = (
                f"Please provide the answer in {settings.output_length} words."
            )

        user_prompt = self._render_template(
            user_prompt_template,
            {
                "query": query,
                "context": context,
                "language": settings.output_language,
                "length_instructions": length_instructions,
            },
        )

        self.logger.debug(
            f"Running inference with model: {settings.inference_model_name}"
        )
        self.logger.debug(f"Final user prompt: {user_prompt}")

        api_client = self._get_api_client()
        completion = api_client.chat.completions.create(
            model=settings.inference_model_name,
            messages=[
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {
                    "role": "user",
                    "content": user_prompt,
                },
            ],
        )
        if completion is None:
            raise Exception("No completion from the API")

        response_str = completion.choices[0].message.content
        return response_str
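
    # The rendered context concatenates the matched chunks as numbered
    # references, e.g.:
    #   [1] text of the first matched chunk ...
    #   [2] text of the second matched chunk ...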

    def run_extract(
        self,
        query: str,
        extract_schema_str: str,
        target_content: str,
        settings: AskSettings,
    ) -> List[TypeVar_BaseModel]:
        target_class = self._get_target_class(extract_schema_str)
        system_prompt = (
            "You are an expert at extracting structured information from documents."
        )
        user_prompt_template = """
Given the provided content, if it contains information about {{ query }}, please extract the
list of structured data items as defined in the following Pydantic schema:

{{ extract_schema_str }}

Below is the provided content:
{{ content }}
"""
        user_prompt = self._render_template(
            user_prompt_template,
            {
                "query": query,
                "content": target_content,
                "extract_schema_str": extract_schema_str,
            },
        )

        self.logger.debug(
            f"Running extraction with model: {settings.inference_model_name}"
        )
        self.logger.debug(f"Final user prompt: {user_prompt}")

        # wrap the target class in a list model so the structured output is a
        # single object with an "items" field
        class_name = target_class.__name__
        list_class_name = f"{class_name}_list"
        response_pydantic_model = create_model(
            list_class_name,
            items=(List[target_class], ...),
        )

        api_client = self._get_api_client()
        completion = api_client.beta.chat.completions.parse(
            model=settings.inference_model_name,
            messages=[
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {
                    "role": "user",
                    "content": user_prompt,
                },
            ],
            response_format=response_pydantic_model,
        )
        if completion is None:
            raise Exception("No completion from the API")

        message = completion.choices[0].message
        if message.refusal:
            raise Exception(
                f"Refused to extract information from the document: {message.refusal}."
            )

        extract_result = message.parsed
        return extract_result.items
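
    # For the hypothetical Product schema, the model is asked to return JSON of
    # the form {"items": [{"name": ..., "price": ...}, ...]}, which parses back
    # into a list of Product instances.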

    def run_query_gradio(
        self,
        query: str,
        date_restrict: int,
        target_site: str,
        output_language: str,
        output_length: int,
        url_list_str: str,
        inference_model_name: str,
        hybrid_search: bool,
        output_mode_str: str,
        extract_schema_str: str,
    ) -> Generator[Tuple[str, str], None, Tuple[str, str]]:
        logger = self.logger
        log_queue = Queue()

        if url_list_str:
            # drop blank lines so they are not treated as URLs
            url_list = [u.strip() for u in url_list_str.split("\n") if u.strip()]
        else:
            url_list = []

        settings = AskSettings(
            date_restrict=date_restrict,
            target_site=target_site,
            output_language=output_language,
            output_length=output_length,
            url_list=url_list,
            inference_model_name=inference_model_name,
            hybrid_search=hybrid_search,
            output_mode=OutputMode(output_mode_str),
            extract_schema_str=extract_schema_str,
        )

        # a bare Handler whose emit is monkey-patched to push formatted
        # records into the queue that feeds the UI log pane
        queue_handler = logging.Handler()
        formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
        queue_handler.emit = lambda record: log_queue.put(formatter.format(record))
        logger.addHandler(queue_handler)

        def update_logs():
            logs = []
            while True:
                try:
                    log = log_queue.get_nowait()
                    logs.append(log)
                except queue.Empty:
                    break
            return "\n".join(logs)

        # wrap the process in a generator to yield the logs to integrate with Gradio
        def process_with_logs():
            if len(settings.url_list) > 0:
                links = settings.url_list
            else:
                logger.info("Searching the web ...")
                yield "", update_logs()
                links = self.search_web(query, settings)
                logger.info(f"✅ Found {len(links)} links for query: {query}")
                for i, link in enumerate(links):
                    logger.debug(f"{i+1}. {link}")
                yield "", update_logs()

            logger.info("Scraping the URLs ...")
            yield "", update_logs()
            scrape_results = self.scrape_urls(links)
            logger.info(f"✅ Scraped {len(scrape_results)} URLs.")
            yield "", update_logs()

            if settings.output_mode == OutputMode.answer:
                logger.info("Chunking the text ...")
                yield "", update_logs()
                chunking_results = self.chunk_results(scrape_results, 1000, 100)
                total_chunks = 0
                for url, chunks in chunking_results.items():
                    logger.debug(f"URL: {url}")
                    total_chunks += len(chunks)
                    for i, chunk in enumerate(chunks):
                        logger.debug(f"Chunk {i+1}: {chunk}")
                logger.info(f"✅ Generated {total_chunks} chunks ...")
                yield "", update_logs()

                logger.info(f"Saving {total_chunks} chunks to DB ...")
                yield "", update_logs()
                table_name = self.save_chunks_to_db(chunking_results)
                logger.info("✅ Successfully embedded and saved chunks to DB.")
                yield "", update_logs()

                logger.info("Querying the vector DB to get context ...")
                matched_chunks = self.vector_search(table_name, query, settings)
                for i, result in enumerate(matched_chunks):
                    logger.debug(f"{i+1}. {result}")
                logger.info(f"✅ Got {len(matched_chunks)} matched chunks.")
                yield "", update_logs()

                logger.info("Running inference with context ...")
                yield "", update_logs()
                answer = self.run_inference(
                    query=query,
                    matched_chunks=matched_chunks,
                    settings=settings,
                )
                logger.info("✅ Finished inference API call.")
                logger.info("Generating output ...")
                yield "", update_logs()

                answer = f"# Answer\n\n{answer}\n"
                references = "\n".join(
                    [
                        f"[{i+1}] {result['url']}"
                        for i, result in enumerate(matched_chunks)
                    ]
                )
                yield f"{answer}\n\n# References\n\n{references}", update_logs()
            elif settings.output_mode == OutputMode.extract:
                logger.info("Extracting structured data ...")
                yield "", update_logs()
                aggregated_output = {}
                for url, text in scrape_results.items():
                    items = self.run_extract(
                        query=query,
                        extract_schema_str=extract_schema_str,
                        target_content=text,
                        settings=settings,
                    )
                    self.logger.info(
                        f"✅ Finished inference API call. Extracted {len(items)} items from {url}."
                    )
                    yield "", update_logs()
                    self.logger.debug(items)
                    aggregated_output[url] = items

                logger.info("✅ Finished extraction from all urls.")
                logger.info("Generating output ...")
                yield "", update_logs()
                answer = _output_csv(aggregated_output, "SourceURL")
                yield f"{answer}", update_logs()
            else:
                raise Exception(f"Invalid output mode: {settings.output_mode}")
| logs = "" | |
| final_result = "" | |
| try: | |
| for result, log_update in process_with_logs(): | |
| logs += log_update + "\n" | |
| final_result = result | |
| yield final_result, logs | |
| finally: | |
| logger.removeHandler(queue_handler) | |
| return final_result, logs | |

    def run_query(
        self,
        query: str,
        settings: AskSettings,
    ) -> str:
        url_list_str = "\n".join(settings.url_list)
        final_result = ""
        for result, _logs in self.run_query_gradio(
            query=query,
            date_restrict=settings.date_restrict,
            target_site=settings.target_site,
            output_language=settings.output_language,
            output_length=settings.output_length,
            url_list_str=url_list_str,
            inference_model_name=settings.inference_model_name,
            hybrid_search=settings.hybrid_search,
            output_mode_str=settings.output_mode.value,
            extract_schema_str=settings.extract_schema_str,
        ):
            final_result = result
        return final_result


def launch_gradio(
    query: str,
    init_settings: AskSettings,
    share_ui: bool,
    logger: logging.Logger,
) -> None:
    ask = Ask(logger=logger)

    def toggle_schema_textbox(option):
        if option == "extract":
            return gr.update(visible=True)
        else:
            return gr.update(visible=False)

    with gr.Blocks() as demo:
        gr.Markdown("# Ask.py - Web Search-Extract-Summarize")
        gr.Markdown(
            "Search the web with the query and summarize the results. "
            "Source code: https://github.com/pengfeng/ask.py"
        )

        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(label="Query", value=query)
                output_mode_input = gr.Radio(
                    label="Output Mode [answer: simple answer, extract: get structured data]",
                    choices=["answer", "extract"],
                    value=init_settings.output_mode,
                )
                extract_schema_input = gr.Textbox(
                    label="Extract Pydantic Schema",
                    visible=(init_settings.output_mode == "extract"),
                    value=init_settings.extract_schema_str,
                    lines=5,
                    max_lines=20,
                )
                output_mode_input.change(
                    fn=toggle_schema_textbox,
                    inputs=output_mode_input,
                    outputs=extract_schema_input,
                )
                date_restrict_input = gr.Number(
                    label="Date Restrict (Optional) [0 or empty means no date limit.]",
                    value=init_settings.date_restrict,
                )
                target_site_input = gr.Textbox(
                    label="Target Sites (Optional) [Empty means searching the whole web.]",
                    value=init_settings.target_site,
                )
                output_language_input = gr.Textbox(
                    label="Output Language (Optional) [Default is English.]",
                    value=init_settings.output_language,
                )
                output_length_input = gr.Number(
                    label="Output Length in words (Optional) [Default is automatically decided by LLM.]",
                    value=init_settings.output_length,
                )
                url_list_input = gr.Textbox(
                    label="URL List (Optional) [When specified, scrape the urls instead of searching the web.]",
                    lines=5,
                    max_lines=20,
                    value="\n".join(init_settings.url_list),
                )

                with gr.Accordion("More Options", open=False):
                    hybrid_search_input = gr.Checkbox(
                        label="Hybrid Search [Use both vector search and full-text search.]",
                        value=init_settings.hybrid_search,
                    )
                    inference_model_name_input = gr.Textbox(
                        label="Inference Model Name",
                        value=init_settings.inference_model_name,
                    )

                submit_button = gr.Button("Submit")

            with gr.Column():
                answer_output = gr.Textbox(label="Answer")
                logs_output = gr.Textbox(label="Logs", lines=10)

        submit_button.click(
            fn=ask.run_query_gradio,
            inputs=[
                query_input,
                date_restrict_input,
                target_site_input,
                output_language_input,
                output_length_input,
                url_list_input,
                inference_model_name_input,
                hybrid_search_input,
                output_mode_input,
                extract_schema_input,
            ],
            outputs=[answer_output, logs_output],
        )

    demo.queue().launch(share=share_ui)


# Note: the click options below are inferred from the function signature and
# the Gradio labels above; treat the defaults as reasonable assumptions rather
# than canonical values. Without these decorators, the bare call in the
# __main__ block would fail with missing positional arguments.
@click.command(help="Search the web for the query, then summarize or extract from the results.")
@click.option("--query", "-q", default=None, help="Query to search for.")
@click.option(
    "--output-mode",
    "-o",
    type=click.Choice(["answer", "extract"]),
    default="answer",
    help="Output mode: a simple answer or structured data extraction.",
)
@click.option(
    "--date-restrict",
    "-d",
    type=int,
    default=0,
    help="Restrict results to the last N days; 0 means no date limit.",
)
@click.option(
    "--target-site",
    "-s",
    default="",
    help="Restrict the search to this site; empty means the whole web.",
)
@click.option("--output-language", default="English", help="Language of the output.")
@click.option(
    "--output-length",
    type=int,
    default=0,
    help="Requested answer length in words; 0 lets the LLM decide.",
)
@click.option(
    "--url-list-file",
    type=str,
    default="",
    help="File with URLs to scrape instead of searching the web.",
)
@click.option(
    "--extract-schema-file",
    type=str,
    default="",
    help="File with the Pydantic schema for extract mode.",
)
@click.option("--inference-model-name", "-m", default="gpt-4o-mini", help="Model used for inference.")
@click.option("--hybrid-search", is_flag=True, help="Use both vector search and full-text search.")
@click.option("--web-ui", is_flag=True, help="Launch the Gradio web UI instead of running once.")
@click.option(
    "--log-level",
    "-l",
    default="INFO",
    type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"]),
    help="Logging level.",
)
def search_extract_summarize(
    query: str,
    output_mode: str,
    date_restrict: int,
    target_site: str,
    output_language: str,
    output_length: int,
    url_list_file: str,
    extract_schema_file: str,
    inference_model_name: str,
    hybrid_search: bool,
    web_ui: bool,
    log_level: str,
):
    load_dotenv(dotenv_path=default_env_file, override=False)
    logger = _get_logger(log_level)

    if output_mode == "extract" and not extract_schema_file:
        raise Exception("Extract mode requires the --extract-schema-file argument.")

    settings = AskSettings(
        date_restrict=date_restrict,
        target_site=target_site,
        output_language=output_language,
        output_length=output_length,
        url_list=_read_url_list(url_list_file),
        inference_model_name=inference_model_name,
        hybrid_search=hybrid_search,
        output_mode=OutputMode(output_mode),
        extract_schema_str=_read_extract_schema_str(extract_schema_file),
    )

    if web_ui or os.environ.get("RUN_GRADIO_UI", "false").lower() != "false":
        share_ui = os.environ.get("SHARE_GRADIO_UI", "false").lower() == "true"
        launch_gradio(
            query=query,
            init_settings=settings,
            share_ui=share_ui,
            logger=logger,
        )
    else:
        if query is None:
            raise Exception("Query is required for the command line mode")
        ask = Ask(logger=logger)
        final_result = ask.run_query(query=query, settings=settings)
        click.echo(final_result)


if __name__ == "__main__":
    search_extract_summarize()
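
# Example invocations (using the flag names defined above; adjust to taste):
#
#   python ask.py -q "What is an LLM agent?"
#   python ask.py -q "Extract the product name and price" -o extract \
#       --extract-schema-file schema.py
#   RUN_GRADIO_UI=true python ask.py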