igiuseppe committed on
Commit
4bb354c
·
1 Parent(s): 2194a0e

using litellm - for now shuffle disabled

Browse files
Files changed (7) hide show
  1. .gitignore +2 -0
  2. core.py +12 -36
  3. core_dt.py +1 -1
  4. eval/compare.py +1 -1
  5. prompts.py +38 -5
  6. requirements.txt +2 -1
  7. utils.py +53 -25
.gitignore CHANGED
@@ -5,3 +5,5 @@ interview_example.txt
5
  eval/results/
6
  eval/__pycache__/
7
  eval/synthetic/
 
 
 
5
  eval/results/
6
  eval/__pycache__/
7
  eval/synthetic/
8
+ test_lite_llm.py
9
+ test_gemini.py
core.py CHANGED
@@ -12,7 +12,8 @@ from prompts import (
12
  GENERATE_REPORT_PROMPT,
13
  CHAT_WITH_REPORT_PROMPT,
14
  GENERATE_AUDIENCE_NAME_PROMPT,
15
- persona_schema
 
16
  )
17
 
18
  logger = logging.getLogger(__name__)
@@ -35,7 +36,7 @@ def generate_user_parameters(audience: str, scope: str,n:int=24) -> List[str]:
35
  class Response(BaseModel):
36
  additional_parameters: list[str]
37
 
38
- response = call_llm(prompt=prompt, response_format=Response,model="gpt-4o-mini",temperature=0)
39
  additional_parameters = json.loads(response)["additional_parameters"]
40
 
41
  return standard_parameters + additional_parameters
@@ -67,14 +68,12 @@ def generate_synthetic_personas(num_personas: int, audience: str, previous_perso
67
  all_new_personas = []
68
  max_iterations = 5 # Safety break to prevent infinite loops
69
  current_iteration = 0
70
-
71
-
72
- response_format = persona_schema
73
 
74
  while len(all_new_personas) < num_personas and current_iteration < max_iterations:
75
  current_iteration += 1
76
  needed_personas = num_personas - len(all_new_personas)
77
  logger.info(f"Iteration {current_iteration}/{max_iterations}: Requesting {needed_personas} more personas (Total needed: {num_personas}, Have: {len(all_new_personas)})...")
 
78
 
79
  # Combine original previous_personas with those generated in this function's previous iterations
80
  current_context_personas = (previous_personas or []) + all_new_personas
@@ -91,7 +90,7 @@ def generate_synthetic_personas(num_personas: int, audience: str, previous_perso
91
  prompt += build_previous_personas_context(current_context_personas) # Appends the formatted list
92
 
93
  try:
94
- response_str = call_llm(prompt=prompt, response_format=response_format,temperature=1, model="gpt-4.1-mini")
95
  response_data = json.loads(response_str)
96
  users_list = response_data.get("users_personas", [])
97
 
@@ -129,7 +128,7 @@ def ask_single_question_to_persona(persona: dict, question: str) -> str:
129
  persona=persona,
130
  question=question
131
  )
132
- answer = call_llm(prompt=prompt,temperature=0, model="gpt-4.1-nano")
133
  return answer
134
  except Exception as e:
135
  logger.error(f"Error asking question '{question}' to persona {persona.get('Name', 'Unknown')}: {e}")
@@ -138,30 +137,7 @@ def ask_single_question_to_persona(persona: dict, question: str) -> str:
138
  def ask_all_questions_to_persona(persona: dict, questions: List[str]) -> str:
139
  """Asks a single question to a single persona and returns the answer."""
140
 
141
- response_format = {
142
- "type": "json_schema",
143
- "json_schema": {
144
- "name": "answers_list",
145
- "schema": {
146
- "type": "object",
147
- "properties": {
148
- "answers": {
149
- "type": "array",
150
- "description": f"A list of answers to questions, with exactly {len(questions)} elements.",
151
- "items": {
152
- "type": "string",
153
- "description": "Each answer corresponding to a question."
154
- }
155
- }
156
- },
157
- "required": [
158
- "answers"
159
- ],
160
- "additionalProperties": False
161
- },
162
- "strict": True
163
- }
164
- }
165
 
166
  try:
167
  prompt = ASK_QUESTIONS_TO_PERSONA_PROMPT.format(
@@ -169,7 +145,7 @@ def ask_all_questions_to_persona(persona: dict, questions: List[str]) -> str:
169
  questions=questions,
170
  num_questions=len(questions)
171
  )
172
- response_str = call_llm(prompt=prompt,temperature=0.5, model="gpt-4.1-mini",response_format=response_format)
173
  response_data = json.loads(response_str)
174
  answers = response_data.get("answers", [])
175
  return answers
@@ -286,7 +262,7 @@ def generate_report(questions,fleet,scope) -> str:
286
  content=content,
287
  scope=scope
288
  )
289
- report_text = call_llm(prompt=prompt,model="gpt-4.1-mini",temperature=0)
290
 
291
  return report_text
292
 
@@ -315,7 +291,7 @@ def chat_with_persona(persona: dict, question: str, conversation_history: List[d
315
  )
316
  if conversation_history:
317
  prompt += f"\nHere you have the previous conversation, make sure to answer the question in a way that is consistent with it:\n{history_context}"
318
- return call_llm(prompt=prompt,temperature=0.5, model="gpt-4.1-mini")
319
 
320
  def chat_with_report(users: List[dict], question: str, questions: List[str]) -> str:
321
  """
@@ -336,7 +312,7 @@ def chat_with_report(users: List[dict], question: str, questions: List[str]) ->
336
  content=content,
337
  question=question
338
  )
339
- return call_llm(prompt=prompt,temperature=0, model="gpt-4.1-nano")
340
 
341
  def generate_audience_name(audience: str, scope: str) -> str:
342
  """
@@ -353,4 +329,4 @@ def generate_audience_name(audience: str, scope: str) -> str:
353
  audience=audience,
354
  scope=scope
355
  )
356
- return call_llm(prompt=prompt,model="gpt-4.1-nano",temperature=0)
 
12
  GENERATE_REPORT_PROMPT,
13
  CHAT_WITH_REPORT_PROMPT,
14
  GENERATE_AUDIENCE_NAME_PROMPT,
15
+ persona_schema,
16
+ answers_schema
17
  )
18
 
19
  logger = logging.getLogger(__name__)
 
36
  class Response(BaseModel):
37
  additional_parameters: list[str]
38
 
39
+ response = call_llm(prompt=prompt, response_format=Response,model_type="mid",temperature=0)
40
  additional_parameters = json.loads(response)["additional_parameters"]
41
 
42
  return standard_parameters + additional_parameters
 
68
  all_new_personas = []
69
  max_iterations = 5 # Safety break to prevent infinite loops
70
  current_iteration = 0
 
 
 
71
 
72
  while len(all_new_personas) < num_personas and current_iteration < max_iterations:
73
  current_iteration += 1
74
  needed_personas = num_personas - len(all_new_personas)
75
  logger.info(f"Iteration {current_iteration}/{max_iterations}: Requesting {needed_personas} more personas (Total needed: {num_personas}, Have: {len(all_new_personas)})...")
76
+ response_format = persona_schema(needed_personas)
77
 
78
  # Combine original previous_personas with those generated in this function's previous iterations
79
  current_context_personas = (previous_personas or []) + all_new_personas
 
90
  prompt += build_previous_personas_context(current_context_personas) # Appends the formatted list
91
 
92
  try:
93
+ response_str = call_llm(prompt=prompt, response_format=response_format,temperature=1, model_type="mid",shuffle=False)
94
  response_data = json.loads(response_str)
95
  users_list = response_data.get("users_personas", [])
96
 
 
128
  persona=persona,
129
  question=question
130
  )
131
+ answer = call_llm(prompt=prompt,temperature=0, model_type="low",shuffle=False)
132
  return answer
133
  except Exception as e:
134
  logger.error(f"Error asking question '{question}' to persona {persona.get('Name', 'Unknown')}: {e}")
 
137
  def ask_all_questions_to_persona(persona: dict, questions: List[str]) -> str:
138
  """Asks a single question to a single persona and returns the answer."""
139
 
140
+ response_format = answers_schema(len(questions))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  try:
143
  prompt = ASK_QUESTIONS_TO_PERSONA_PROMPT.format(
 
145
  questions=questions,
146
  num_questions=len(questions)
147
  )
148
+ response_str = call_llm(prompt=prompt,temperature=0.5, model_type="mid",response_format=response_format,shuffle=False)
149
  response_data = json.loads(response_str)
150
  answers = response_data.get("answers", [])
151
  return answers
 
262
  content=content,
263
  scope=scope
264
  )
265
+ report_text = call_llm(prompt=prompt,model_type="mid",temperature=0)
266
 
267
  return report_text
268
 
 
291
  )
292
  if conversation_history:
293
  prompt += f"\nHere you have the previous conversation, make sure to answer the question in a way that is consistent with it:\n{history_context}"
294
+ return call_llm(prompt=prompt,temperature=0.5, model_type="mid",shuffle=False)
295
 
296
  def chat_with_report(users: List[dict], question: str, questions: List[str]) -> str:
297
  """
 
312
  content=content,
313
  question=question
314
  )
315
+ return call_llm(prompt=prompt,temperature=0, model_type="low")
316
 
317
  def generate_audience_name(audience: str, scope: str) -> str:
318
  """
 
329
  audience=audience,
330
  scope=scope
331
  )
332
+ return call_llm(prompt=prompt,model_type="low",temperature=0)
core_dt.py CHANGED
@@ -84,7 +84,7 @@ The result should be in plain text.
84
  Here is the text:
85
  {agent_particularities}
86
  """
87
- return call_llm(prompt=prompt)
88
 
89
  def generate_new_memory(n,person):
90
  return person.retrieve_recent_memories(include_omission_info=False)[-n:]
 
84
  Here is the text:
85
  {agent_particularities}
86
  """
87
+ return call_llm(prompt=prompt,model_type="mid",temperature=0.5)
88
 
89
  def generate_new_memory(n,person):
90
  return person.retrieve_recent_memories(include_omission_info=False)[-n:]
eval/compare.py CHANGED
@@ -104,7 +104,7 @@ The structure of your output must be a simple list of insights.
104
  """
105
 
106
  logger.info(f"Extracting insights for {audience_type} audience...")
107
- insights = call_llm(prompt=prompt,temperature=0, model="gpt-4.1-mini")
108
  logger.info(f"Successfully extracted insights for {audience_type} audience")
109
 
110
  return insights
 
104
  """
105
 
106
  logger.info(f"Extracting insights for {audience_type} audience...")
107
+ insights = call_llm(prompt=prompt,temperature=0, model_type="mid")
108
  logger.info(f"Successfully extracted insights for {audience_type} audience")
109
 
110
  return insights
prompts.py CHANGED
@@ -276,7 +276,8 @@ Create a very concise name (max one sentence) that captures the essence of this
276
  Respond with ONLY the name, nothing else.
277
  """
278
 
279
- persona_schema={
 
280
  "type": "json_schema",
281
  "json_schema": {
282
  "name": "user_personas_response",
@@ -285,10 +286,10 @@ persona_schema={
285
  "properties": {
286
  "users_personas": {
287
  "type": "array",
288
- "description": "An array of synthetic user personas.",
289
  "items": {
290
  "type": "object",
291
- "description": "A single synthetic user persona defined by 20 key parameters.",
292
  "properties": {
293
  "Name": {
294
  "type": "string",
@@ -513,7 +514,9 @@ persona_schema={
513
  "general_interests_and_hobbies"
514
  ],
515
  "additionalProperties": False
516
- }
 
 
517
  }
518
  },
519
  "required": ["users_personas"],
@@ -521,4 +524,34 @@ persona_schema={
521
  },
522
  "strict": True
523
  }
524
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  Respond with ONLY the name, nothing else.
277
  """
278
 
279
+ def persona_schema(n):
280
+ persona_schema={
281
  "type": "json_schema",
282
  "json_schema": {
283
  "name": "user_personas_response",
 
286
  "properties": {
287
  "users_personas": {
288
  "type": "array",
289
+ "description": f"An array of exactly {n} synthetic user personas.",
290
  "items": {
291
  "type": "object",
292
+ "description": "A single synthetic user persona defined by the following key parameters.",
293
  "properties": {
294
  "Name": {
295
  "type": "string",
 
514
  "general_interests_and_hobbies"
515
  ],
516
  "additionalProperties": False
517
+ },
518
+ "minItems": n,
519
+ "maxItems": n
520
  }
521
  },
522
  "required": ["users_personas"],
 
524
  },
525
  "strict": True
526
  }
527
+ }
528
+ return persona_schema
529
+
530
def answers_schema(n):
    """Return an OpenAI-style JSON-schema response format describing an
    object with an "answers" array of exactly *n* string elements."""
    item_spec = {
        "type": "string",
        "description": "Each answer corresponding to a question.",
    }
    answers_spec = {
        "type": "array",
        "description": f"A list of answers to questions, with exactly {n} elements.",
        "items": item_spec,
        "minItems": n,
        "maxItems": n,
    }
    object_spec = {
        "type": "object",
        "properties": {"answers": answers_spec},
        "required": ["answers"],
        "additionalProperties": False,
    }
    return {
        "type": "json_schema",
        "json_schema": {
            "name": "answers_list",
            "schema": object_spec,
            "strict": True,
        },
    }
requirements.txt CHANGED
@@ -6,4 +6,5 @@ requests==2.32.3
6
  gradio==5.23.2
7
  fastapi==0.115.12
8
  uvicorn==0.34.0
9
- git+https://github.com/igiuseppe/TinyTroupeFork.git
 
 
6
  gradio==5.23.2
7
  fastapi==0.115.12
8
  uvicorn==0.34.0
9
+ git+https://github.com/igiuseppe/TinyTroupeFork.git
10
+ litellm==1.71.2
utils.py CHANGED
@@ -1,29 +1,57 @@
1
- import openai
 
 
 
 
 
 
2
 
3
- LLM_MODEL = "gpt-4.1-nano"
4
- temperature = 0.5
5
- frequency_penalty=0
6
- presence_penalty=0
7
- top_p=0
8
 
9
- def call_llm(prompt: str, response_format=None, model=LLM_MODEL,temperature=temperature,frequency_penalty=frequency_penalty,presence_penalty=presence_penalty,top_p=top_p) -> str:
10
- client = openai.OpenAI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  if response_format:
12
- response = client.beta.chat.completions.parse(
13
- model=model,
14
- messages=[{"role": "user", "content": prompt}],
15
- response_format=response_format,
16
- temperature=temperature,
17
- frequency_penalty=frequency_penalty,
18
- presence_penalty=presence_penalty,
19
- top_p=top_p
20
- )
 
 
21
  else:
22
- response = client.chat.completions.create(
23
- model=model,
24
- messages=[{"role": "user", "content": prompt}],
25
- temperature=temperature,
26
- frequency_penalty=0.0,
27
- presence_penalty=0.0
28
- )
29
- return response.choices[0].message.content
 
1
+ from litellm import completion, _turn_on_debug
2
+ from dotenv import load_dotenv
3
+ import random
4
+ import logging
5
+ load_dotenv()
6
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
7
+ logger = logging.getLogger(__name__)
8
 
9
+ models_low=["gemini/gemini-2.0-flash","openai/gpt-4.1-nano"]
10
+ models_mid=["gemini/gemini-2.5-flash-preview-05-20","openai/gpt-4.1-mini"]
11
+ models_high=["gemini/gemini-2.5-pro-preview-05-06","openai/gpt-4.1"]
 
 
12
 
13
+ model_low="openai/gpt-4.1-nano"
14
+ model_mid="openai/gpt-4.1-mini"
15
+ model_high="openai/gpt-4.1"
16
+
17
def call_llm(prompt: str, temperature: float,model_type: str,response_format=None,tools=None,shuffle=False,return_tokens=False) -> str:
    """Send a single user-message completion request through litellm.

    Args:
        prompt: The user prompt to send as the sole message.
        temperature: Sampling temperature forwarded to the model.
        model_type: Capability tier, one of "low", "mid" or "high".
        response_format: Optional JSON-schema response format, passed through.
        tools: Optional tool definitions, passed through.
        shuffle: When True, pick a random provider model from the tier's pool;
            otherwise use the fixed per-tier model.
        return_tokens: When True, also return (input_tokens, output_tokens).

    Returns:
        The response text, or a tuple (text, input_tokens, output_tokens)
        when return_tokens is True.

    Raises:
        ValueError: If model_type is not a recognised tier. (The original
            code left `model` unassigned here, producing an opaque
            UnboundLocalError at the completion() call.)
    """
    # One table instead of two parallel if/elif ladders: tier -> (pool, fixed).
    tiers = {
        "low": (models_low, model_low),
        "mid": (models_mid, model_mid),
        "high": (models_high, model_high),
    }
    if model_type not in tiers:
        raise ValueError(f"Unknown model_type: {model_type!r} (expected 'low', 'mid' or 'high')")
    pool, fixed = tiers[model_type]

    if shuffle:
        model = random.choice(pool)
        logger.info(f"SHUFFLE. Using model: {model}")
    else:
        model = fixed

    messages=[
        {"role": "user", "content": prompt},
    ]

    completion_args = {
        "model": model,
        "messages": messages,
        "temperature": temperature
    }

    # Only attach optional arguments when provided, so provider defaults apply.
    if response_format:
        completion_args["response_format"] = response_format

    if tools:
        completion_args["tools"] = tools

    response = completion(**completion_args)
    response_str = response.choices[0].message.content
    if return_tokens:
        # usage follows the OpenAI-style schema litellm normalises responses to
        output_tokens = response.usage.completion_tokens
        input_tokens = response.usage.prompt_tokens
        return response_str,input_tokens,output_tokens
    else:
        return response_str