Nicolás Larenas committed on
Commit
1a9430f
·
verified ·
1 Parent(s): 98a619d

Update ai_model.py

Browse files
Files changed (1) hide show
  1. ai_model.py +51 -29
ai_model.py CHANGED
@@ -1,42 +1,64 @@
1
  import google.generativeai as genai
2
  import os
 
 
3
 
4
  # Load Google API key from environment
5
  GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
6
  genai.configure(api_key=GOOGLE_API_KEY)
7
 
8
async def query_ai_model(message, history=None, system_message=None, max_new_tokens=512, temperature=0.7, top_p=0.95):
    """Send *message* (with optional chat history) to the PaLM chat model.

    Args:
        message: The new user message to send.
        history: Optional list of (user_msg, bot_reply) pairs of prior turns.
        system_message: Optional system prompt prepended to the conversation.
        max_new_tokens: Cap on generated tokens (coerced to int).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling parameter.

    Returns:
        The model's reply text, or an "An error occurred: ..." string if
        anything raised.
    """
    try:
        # Assemble the conversation in the {'author', 'content'} format the
        # chat API expects: system prompt first, then alternating turns.
        conversation = []
        if system_message:
            conversation.append({'author': 'system', 'content': system_message})
        for user_msg, bot_reply in (history or []):
            if user_msg:
                conversation.append({'author': 'user', 'content': user_msg})
            if bot_reply:
                conversation.append({'author': 'assistant', 'content': bot_reply})
        # The new message always goes last.
        conversation.append({'author': 'user', 'content': message})

        response = genai.generate_chat(
            model='chat-bison-001',  # Update model name as necessary
            messages=conversation,
            temperature=temperature,
            top_p=top_p,
            max_output_tokens=int(max_new_tokens),
        )
        return response.candidates[0]['content']
    except Exception as e:
        return f"An error occurred: {str(e)}"
 
1
  import google.generativeai as genai
2
  import os
3
+ from typing import List, Dict, Union
4
+ from PIL import Image
5
 
6
  # Load Google API key from environment
7
  GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
8
  genai.configure(api_key=GOOGLE_API_KEY)
9
 
10
# Directory where uploaded images are cached before being sent to the vision model.
IMAGE_CACHE_DIRECTORY = "/tmp"
# Fixed output width (pixels) used when resizing images; height scales proportionally.
IMAGE_WIDTH = 512
 
 
 
 
 
 
12
 
13
def preprocess_image(image: "Image.Image", target_width: int = None) -> "Image.Image":
    """Resize *image* to a fixed width, preserving its aspect ratio.

    Args:
        image: Source PIL image.
        target_width: Desired output width in pixels. Defaults to the
            module-level ``IMAGE_WIDTH`` constant, preserving the original
            behavior for existing callers.

    Returns:
        A new resized image; the input image is not modified.
    """
    if target_width is None:
        target_width = IMAGE_WIDTH
    # Scale the height by the same factor as the width so the aspect
    # ratio is preserved (int() truncates to whole pixels).
    image_height = int(image.height * target_width / image.width)
    return image.resize((target_width, image_height))
 
 
 
16
 
17
def cache_pil_image(image: "Image.Image", cache_dir: str = None) -> str:
    """Save *image* as a JPEG under a unique filename and return its path.

    Args:
        image: PIL image to persist.
        cache_dir: Directory to write into; created if missing. Defaults to
            the module-level ``IMAGE_CACHE_DIRECTORY``, preserving the
            original behavior for existing callers.

    Returns:
        Filesystem path of the saved JPEG file.
    """
    import uuid  # local import kept from the original; only needed here

    if cache_dir is None:
        cache_dir = IMAGE_CACHE_DIRECTORY
    # uuid4 gives a collision-safe name, so concurrent saves never clash.
    image_filename = f"{uuid.uuid4()}.jpeg"
    os.makedirs(cache_dir, exist_ok=True)
    image_path = os.path.join(cache_dir, image_filename)
    # NOTE(review): JPEG cannot store an alpha channel — assumes callers
    # pass RGB images (the vision path does .convert('RGB')); confirm.
    image.save(image_path, "JPEG")
    return image_path
24
 
25
async def query_ai_model(
    messages: List[Dict[str, Union[str, List[str]]]],
    temperature: float,
    max_output_tokens: int,
    stop_sequences: List[str],
    top_k: int,
    top_p: float,
    use_vision: bool = False,
    image_files: List[str] = None,
):
    """Query a Gemini model (text or vision) and return the generated text.

    Args:
        messages: Chat history in genai format: each item is a dict with a
            'role' key and a 'parts' list of strings.
        temperature: Sampling temperature.
        max_output_tokens: Cap on generated tokens.
        stop_sequences: Stop sequences; an empty list is normalized to None
            so the API default applies.
        top_k: Top-k sampling parameter.
        top_p: Nucleus-sampling parameter.
        use_vision: When True (and image_files is non-empty), route the
            request to the multimodal 'gemini-pro-vision' model.
        image_files: Paths of image files to attach in vision mode; may be
            None when use_vision is False.

    Returns:
        The generated response text, or an "An error occurred: ..." string
        if anything raised.
    """
    try:
        generation_config = genai.types.GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            stop_sequences=stop_sequences if stop_sequences else None,
            top_k=top_k,
            top_p=top_p,
        )
        if use_vision and image_files:
            # Vision model: combine user text parts and decoded images into
            # one multimodal prompt. Convert to RGB so JPEG/alpha sources
            # are accepted by the model.
            image_prompt = [Image.open(file).convert('RGB') for file in image_files]
            # NOTE(review): assumes each user message stores its text as
            # parts[0] — confirm against the caller's message format.
            text_prompt = [msg['parts'][0] for msg in messages if msg['role'] == 'user']
            model = genai.GenerativeModel('gemini-pro-vision')
            response = model.generate_content(
                text_prompt + image_prompt,
                stream=False,
                generation_config=generation_config
            )
        else:
            # Text-only model gets the structured message history directly.
            model = genai.GenerativeModel('gemini-pro')
            response = model.generate_content(
                messages,
                stream=False,
                generation_config=generation_config
            )
        # Bug fix: GenerateContentResponse exposes the generated text via
        # `.text`; `.result` is not part of the public response API.
        return response.text
    except Exception as e:
        return f"An error occurred: {str(e)}"