Xenobd committed on
Commit
9c39f4f
·
verified ·
1 Parent(s): 3336811

Update src/analyzer/openai_analyzer.py

Browse files
Files changed (1) hide show
  1. src/analyzer/openai_analyzer.py +29 -44
src/analyzer/openai_analyzer.py CHANGED
@@ -1,46 +1,31 @@
 
 
1
 
2
- from typing import List
3
- from openai import OpenAI
4
-
5
- from src.core.types import DEFAULT_OPENAI_ANALYZER, DEFAULT_SYSTEM_PROMPT, DEFAULT_USER_PROMPT
6
-
7
- from src.models.analyzer_models import AnalyzerResult
8
- from src.models.scrape_models import ScrapeResult
9
-
10
- from src.core.interface.analyzer_interface import AnalyzerInterface
11
-
12
-
13
class OpenaiAnalyzer(AnalyzerInterface):
    """Analyzer backed by the OpenAI chat-completions API with structured output parsing."""

    def __init__(self, api_key, model_name=DEFAULT_OPENAI_ANALYZER):
        # model_name defaults to the project-wide analyzer model constant.
        self.client = OpenAI(api_key=api_key)
        self.model_name = model_name

    def analyze_search_result(self, query: str, search_results: List[ScrapeResult]) -> AnalyzerResult:
        """
        Analyze the provided search results based on the given query.

        Args:
            query (str): The search query string.
            search_results (List[ScrapeResult]): A list of search results to be analyzed.

        Returns:
            AnalyzerResult: The structured analysis parsed from the model response.

        Raises:
            Exception: Wraps any error raised during the API call; the original
                exception is chained as the cause.
        """
        try:
            # NOTE(review): plain substring replacement — every occurrence of the
            # literal tokens "query" / "scrape_results" in the template gets
            # substituted; assumes the template uses them only as placeholders.
            # TODO confirm against DEFAULT_USER_PROMPT.
            user_prompt = DEFAULT_USER_PROMPT.replace("query", query).replace(
                "scrape_results", f"{search_results}"
            )
            completion = self.client.beta.chat.completions.parse(
                model=self.model_name,
                messages=[
                    {
                        "role": "system",
                        "content": DEFAULT_SYSTEM_PROMPT,
                    },
                    {
                        "role": "user",
                        "content": user_prompt,
                    },
                ],
                response_format=AnalyzerResult,
            )
            # .parsed holds the response validated against the AnalyzerResult schema.
            return completion.choices[0].message.parsed
        except Exception as e:
            # Chain the cause so the original traceback is preserved for callers.
            raise Exception(f"Error while analyzing search result: {str(e)}") from e
 
 
 
 
 
1
+ # token_utils.py
2
+ import tiktoken
3
 
4
class TokenCounter:
    """Utility class for counting and managing tokens."""

    @staticmethod
    def get_encoder(model_name: str):
        """Return the tiktoken encoder for *model_name*.

        Falls back to the "cl100k_base" encoding when the model name is not
        recognized by tiktoken.
        """
        try:
            return tiktoken.encoding_for_model(model_name)
        except KeyError:
            # The original bare `except:` also swallowed SystemExit and
            # KeyboardInterrupt; tiktoken signals an unknown model name with
            # KeyError, so catch only that.
            return tiktoken.get_encoding("cl100k_base")  # Default for newer models

    @staticmethod
    def estimate_prompt_tokens(model_name: str, system_prompt: str, user_prompt: str) -> int:
        """Estimate the combined token count of the system and user prompts."""
        encoder = TokenCounter.get_encoder(model_name)
        return len(encoder.encode(system_prompt)) + len(encoder.encode(user_prompt))

    @staticmethod
    def get_model_context_limit(model_name: str) -> int:
        """Return the context-window size, in tokens, for *model_name*.

        Unknown models fall back to 8192 tokens.
        """
        limits = {
            "gpt-3.5-turbo": 16385,
            "gpt-4": 8192,
            "gpt-4-turbo": 128000,
            "gpt-4o": 128000,
            "gpt-4o-mini": 128000,
        }
        return limits.get(model_name, 8192)  # Default to 8192 if unknown