Nicolás Larenas committed (verified)
Commit bed6b6d · 1 Parent(s): 7966583

Update ai_model.py

Files changed (1):
  1. ai_model.py +52 -51
ai_model.py CHANGED
@@ -1,64 +1,65 @@
 import google.generativeai as genai
 import os
-from typing import List, Dict, Union
-from PIL import Image
+from config import (
+    SYSTEM_MESSAGE,
+    MODEL_NAME,
+    DEFAULT_MAX_NEW_TOKENS,
+    DEFAULT_TEMPERATURE,
+    DEFAULT_TOP_P,
+    DEFAULT_TOP_K,
+    DEFAULT_STOP_SEQUENCES,
+)

 # Load Google API key from environment
 GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
-genai.configure(api_key=GOOGLE_API_KEY)
-
-IMAGE_CACHE_DIRECTORY = "/tmp"
-IMAGE_WIDTH = 512
-
-def preprocess_image(image: Image.Image) -> Image.Image:
-    image_height = int(image.height * IMAGE_WIDTH / image.width)
-    return image.resize((IMAGE_WIDTH, image_height))
+if not GOOGLE_API_KEY:
+    raise ValueError("GOOGLE_API_KEY is not set. Please provide your API key.")

-def cache_pil_image(image: Image.Image) -> str:
-    import uuid
-    image_filename = f"{uuid.uuid4()}.jpeg"
-    os.makedirs(IMAGE_CACHE_DIRECTORY, exist_ok=True)
-    image_path = os.path.join(IMAGE_CACHE_DIRECTORY, image_filename)
-    image.save(image_path, "JPEG")
-    return image_path
+genai.configure(api_key=GOOGLE_API_KEY)

+# Query AI model
 async def query_ai_model(
-    messages: List[Dict[str, Union[str, List[str]]]],
-    temperature: float,
-    max_output_tokens: int,
-    stop_sequences: List[str],
-    top_k: int,
-    top_p: float,
-    use_vision: bool = False,
-    image_files: List[str] = None,
+    message,
+    history=None,
+    system_message=SYSTEM_MESSAGE,
+    max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
+    temperature=DEFAULT_TEMPERATURE,
+    top_p=DEFAULT_TOP_P,
+    top_k=DEFAULT_TOP_K,
+    stop_sequences=DEFAULT_STOP_SEQUENCES,
 ):
     try:
-        generation_config = genai.types.GenerationConfig(
-            temperature=temperature,
-            max_output_tokens=max_output_tokens,
-            stop_sequences=stop_sequences if stop_sequences else None,
-            top_k=top_k,
-            top_p=top_p,
+        # Build the conversation history in the required format
+        messages = []
+
+        if system_message:
+            messages.append({'role': 'system', 'content': system_message})
+
+        if history:
+            for user_msg, bot_reply in history:
+                if user_msg:
+                    messages.append({'role': 'user', 'content': user_msg})
+                if bot_reply:
+                    messages.append({'role': 'assistant', 'content': bot_reply})
+
+        # Append the new user message
+        messages.append({'role': 'user', 'content': message})
+
+        # Set parameters
+        parameters = {
+            'temperature': temperature,
+            'top_p': top_p,
+            'top_k': top_k,
+            'max_output_tokens': int(max_new_tokens),
+            'stop_sequences': stop_sequences,
+        }
+
+        # Generate response
+        response = genai.generate_chat(
+            model=MODEL_NAME,
+            messages=messages,
+            **parameters
         )
-        if use_vision and image_files:
-            # For vision model
-            image_prompt = [Image.open(file).convert('RGB') for file in image_files]
-            text_prompt = [msg['parts'][0] for msg in messages if msg['role'] == 'user']
-            model = genai.GenerativeModel('gemini-pro-vision')
-            response = model.generate_content(
-                text_prompt + image_prompt,
-                stream=False,
-                generation_config=generation_config
-            )
-        else:
-            # For text model
-            model = genai.GenerativeModel('gemini-pro')
-            response = model.generate_content(
-                messages,
-                stream=False,
-                generation_config=generation_config
-            )
-        # Since we are not streaming, get the full response text
-        return response.result
+        return response.candidates[0]['content']
     except Exception as e:
         return f"An error occurred: {str(e)}"