Nicolás Larenas commited on
Commit
791ba42
·
verified ·
1 Parent(s): da17315

Update ai_model.py

Browse files
Files changed (1) hide show
  1. ai_model.py +61 -3
ai_model.py CHANGED
@@ -1,9 +1,12 @@
 
 
1
  import google.generativeai as genai
2
  import os
3
  import logging
4
  from config import (
5
  SYSTEM_INSTRUCTION,
6
  MODEL_NAME,
 
7
  DEFAULT_MAX_OUTPUT_TOKENS,
8
  DEFAULT_TEMPERATURE,
9
  DEFAULT_TOP_P,
@@ -22,13 +25,18 @@ if not GOOGLE_API_KEY:
22
 
23
  genai.configure(api_key=GOOGLE_API_KEY)
24
 
25
- # Initialize the model with system instruction
26
  model = genai.GenerativeModel(
27
  model_name=MODEL_NAME,
28
  system_instruction=SYSTEM_INSTRUCTION
29
  )
30
 
31
- # Query AI model
 
 
 
 
 
32
  async def query_ai_model(
33
  message,
34
  history=None,
@@ -40,7 +48,7 @@ async def query_ai_model(
40
  ):
41
  try:
42
  # Build the conversation history
43
- messages = history.copy() if history else []
44
 
45
  # Append the new user message
46
  messages.append({'role': 'user', 'content': message})
@@ -67,3 +75,53 @@ async def query_ai_model(
67
  except Exception as e:
68
  logging.error("Error in query_ai_model", exc_info=True)
69
  return {'role': 'assistant', 'content': f"An error occurred: {str(e)}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ai_model.py
2
+
3
  import google.generativeai as genai
4
  import os
5
  import logging
6
  from config import (
7
  SYSTEM_INSTRUCTION,
8
  MODEL_NAME,
9
+ MODEL_VISION_NAME,
10
  DEFAULT_MAX_OUTPUT_TOKENS,
11
  DEFAULT_TEMPERATURE,
12
  DEFAULT_TOP_P,
 
25
 
26
  genai.configure(api_key=GOOGLE_API_KEY)
27
 
28
+ # Initialize the models with system instruction
29
  model = genai.GenerativeModel(
30
  model_name=MODEL_NAME,
31
  system_instruction=SYSTEM_INSTRUCTION
32
  )
33
 
34
+ model_vision = genai.GenerativeModel(
35
+ model_name=MODEL_VISION_NAME,
36
+ system_instruction=SYSTEM_INSTRUCTION
37
+ )
38
+
39
+ # Query AI model for text-only input
40
  async def query_ai_model(
41
  message,
42
  history=None,
 
48
  ):
49
  try:
50
  # Build the conversation history
51
+ messages = preprocess_chat_history(history) if history else []
52
 
53
  # Append the new user message
54
  messages.append({'role': 'user', 'content': message})
 
75
  except Exception as e:
76
  logging.error("Error in query_ai_model", exc_info=True)
77
  return {'role': 'assistant', 'content': f"An error occurred: {str(e)}"}
78
+
79
+ # Query AI model for text-and-image input
80
+ async def query_ai_model_vision(
81
+ text_prompt: List[str],
82
+ image_prompt: List,
83
+ max_output_tokens=DEFAULT_MAX_OUTPUT_TOKENS,
84
+ temperature=DEFAULT_TEMPERATURE,
85
+ top_p=DEFAULT_TOP_P,
86
+ top_k=DEFAULT_TOP_K,
87
+ stop_sequences=DEFAULT_STOP_SEQUENCES,
88
+ ):
89
+ try:
90
+ # Combine text and image prompts
91
+ combined_prompt = text_prompt + image_prompt
92
+
93
+ # Set generation configuration
94
+ generation_config = genai.types.GenerationConfig(
95
+ temperature=temperature,
96
+ top_p=top_p,
97
+ top_k=top_k,
98
+ max_output_tokens=int(max_output_tokens),
99
+ stop_sequences=stop_sequences,
100
+ )
101
+
102
+ # Generate response with vision model
103
+ response = model_vision.generate_content(
104
+ combined_prompt,
105
+ generation_config=generation_config
106
+ )
107
+
108
+ # Extract the assistant's reply
109
+ assistant_reply = {'role': 'assistant', 'content': response.text}
110
+
111
+ return assistant_reply
112
+ except Exception as e:
113
+ logging.error("Error in query_ai_model_vision", exc_info=True)
114
+ return {'role': 'assistant', 'content': f"An error occurred: {str(e)}"}
115
+
116
+ # Preprocess chat history to the required format
117
+ def preprocess_chat_history(history: List[tuple]) -> List[dict]:
118
+ messages = []
119
+ for user_message, model_message in history:
120
+ if isinstance(user_message, tuple):
121
+ # Handle image inputs if necessary
122
+ pass
123
+ elif user_message is not None:
124
+ messages.append({'role': 'user', 'content': user_message})
125
+ if model_message is not None:
126
+ messages.append({'role': 'assistant', 'content': model_message})
127
+ return messages