add support for Gemini (#1465)
Browse files### What problem does this PR solve?
#1036
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
- api/db/init_data.py +36 -1
- rag/llm/chat_model.py +61 -0
- rag/llm/cv_model.py +23 -0
- rag/llm/embedding_model.py +25 -1
- requirements.txt +1 -0
- requirements_arm.txt +1 -0
- requirements_dev.txt +1 -0
- web/src/assets/svg/llm/gemini.svg +114 -0
- web/src/pages/user-setting/setting-model/index.tsx +1 -0
api/db/init_data.py
CHANGED
|
@@ -175,6 +175,11 @@ factory_infos = [{
|
|
| 175 |
"logo": "",
|
| 176 |
"tags": "LLM,TEXT EMBEDDING",
|
| 177 |
"status": "1",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
}
|
| 179 |
# {
|
| 180 |
# "name": "文心一言",
|
|
@@ -898,7 +903,37 @@ def init_llm_factory():
|
|
| 898 |
"tags": "TEXT EMBEDDING",
|
| 899 |
"max_tokens": 2048,
|
| 900 |
"model_type": LLMType.EMBEDDING.value
|
| 901 |
-
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 902 |
]
|
| 903 |
for info in factory_infos:
|
| 904 |
try:
|
|
|
|
| 175 |
"logo": "",
|
| 176 |
"tags": "LLM,TEXT EMBEDDING",
|
| 177 |
"status": "1",
|
| 178 |
+
},{
|
| 179 |
+
"name": "Gemini",
|
| 180 |
+
"logo": "",
|
| 181 |
+
"tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
|
| 182 |
+
"status": "1",
|
| 183 |
}
|
| 184 |
# {
|
| 185 |
# "name": "文心一言",
|
|
|
|
| 903 |
"tags": "TEXT EMBEDDING",
|
| 904 |
"max_tokens": 2048,
|
| 905 |
"model_type": LLMType.EMBEDDING.value
|
| 906 |
+
}, {
|
| 907 |
+
"fid": factory_infos[17]["name"],
|
| 908 |
+
"llm_name": "gemini-1.5-pro-latest",
|
| 909 |
+
"tags": "LLM,CHAT,1024K",
|
| 910 |
+
"max_tokens": 1024*1024,
|
| 911 |
+
"model_type": LLMType.CHAT.value
|
| 912 |
+
}, {
|
| 913 |
+
"fid": factory_infos[17]["name"],
|
| 914 |
+
"llm_name": "gemini-1.5-flash-latest",
|
| 915 |
+
"tags": "LLM,CHAT,1024K",
|
| 916 |
+
"max_tokens": 1024*1024,
|
| 917 |
+
"model_type": LLMType.CHAT.value
|
| 918 |
+
}, {
|
| 919 |
+
"fid": factory_infos[17]["name"],
|
| 920 |
+
"llm_name": "gemini-1.0-pro",
|
| 921 |
+
"tags": "LLM,CHAT,30K",
|
| 922 |
+
"max_tokens": 30*1024,
|
| 923 |
+
"model_type": LLMType.CHAT.value
|
| 924 |
+
}, {
|
| 925 |
+
"fid": factory_infos[17]["name"],
|
| 926 |
+
"llm_name": "gemini-1.0-pro-vision-latest",
|
| 927 |
+
"tags": "LLM,IMAGE2TEXT,12K",
|
| 928 |
+
"max_tokens": 12*1024,
|
| 929 |
+
"model_type": LLMType.IMAGE2TEXT.value
|
| 930 |
+
}, {
|
| 931 |
+
"fid": factory_infos[17]["name"],
|
| 932 |
+
"llm_name": "text-embedding-004",
|
| 933 |
+
"tags": "TEXT EMBEDDING",
|
| 934 |
+
"max_tokens": 2048,
|
| 935 |
+
"model_type": LLMType.EMBEDDING.value
|
| 936 |
+
}
|
| 937 |
]
|
| 938 |
for info in factory_infos:
|
| 939 |
try:
|
rag/llm/chat_model.py
CHANGED
|
@@ -621,3 +621,64 @@ class BedrockChat(Base):
|
|
| 621 |
yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
|
| 622 |
|
| 623 |
yield num_tokens_from_string(ans)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
|
| 622 |
|
| 623 |
yield num_tokens_from_string(ans)
|
| 624 |
+
|
| 625 |
+
class GeminiChat(Base):
|
| 626 |
+
|
| 627 |
+
def __init__(self, key, model_name,base_url=None):
|
| 628 |
+
from google.generativeai import client,GenerativeModel
|
| 629 |
+
|
| 630 |
+
client.configure(api_key=key)
|
| 631 |
+
_client = client.get_default_generative_client()
|
| 632 |
+
self.model_name = 'models/' + model_name
|
| 633 |
+
self.model = GenerativeModel(model_name=self.model_name)
|
| 634 |
+
self.model._client = _client
|
| 635 |
+
|
| 636 |
+
def chat(self,system,history,gen_conf):
|
| 637 |
+
if system:
|
| 638 |
+
history.insert(0, {"role": "user", "parts": system})
|
| 639 |
+
if 'max_tokens' in gen_conf:
|
| 640 |
+
gen_conf['max_output_tokens'] = gen_conf['max_tokens']
|
| 641 |
+
for k in list(gen_conf.keys()):
|
| 642 |
+
if k not in ["temperature", "top_p", "max_output_tokens"]:
|
| 643 |
+
del gen_conf[k]
|
| 644 |
+
for item in history:
|
| 645 |
+
if 'role' in item and item['role'] == 'assistant':
|
| 646 |
+
item['role'] = 'model'
|
| 647 |
+
if 'content' in item :
|
| 648 |
+
item['parts'] = item.pop('content')
|
| 649 |
+
|
| 650 |
+
try:
|
| 651 |
+
response = self.model.generate_content(
|
| 652 |
+
history,
|
| 653 |
+
generation_config=gen_conf)
|
| 654 |
+
ans = response.text
|
| 655 |
+
return ans, response.usage_metadata.total_token_count
|
| 656 |
+
except Exception as e:
|
| 657 |
+
return "**ERROR**: " + str(e), 0
|
| 658 |
+
|
| 659 |
+
def chat_streamly(self, system, history, gen_conf):
|
| 660 |
+
if system:
|
| 661 |
+
history.insert(0, {"role": "user", "parts": system})
|
| 662 |
+
if 'max_tokens' in gen_conf:
|
| 663 |
+
gen_conf['max_output_tokens'] = gen_conf['max_tokens']
|
| 664 |
+
for k in list(gen_conf.keys()):
|
| 665 |
+
if k not in ["temperature", "top_p", "max_output_tokens"]:
|
| 666 |
+
del gen_conf[k]
|
| 667 |
+
for item in history:
|
| 668 |
+
if 'role' in item and item['role'] == 'assistant':
|
| 669 |
+
item['role'] = 'model'
|
| 670 |
+
if 'content' in item :
|
| 671 |
+
item['parts'] = item.pop('content')
|
| 672 |
+
ans = ""
|
| 673 |
+
try:
|
| 674 |
+
response = self.model.generate_content(
|
| 675 |
+
history,
|
| 676 |
+
generation_config=gen_conf,stream=True)
|
| 677 |
+
for resp in response:
|
| 678 |
+
ans += resp.text
|
| 679 |
+
yield ans
|
| 680 |
+
|
| 681 |
+
except Exception as e:
|
| 682 |
+
yield ans + "\n**ERROR**: " + str(e)
|
| 683 |
+
|
| 684 |
+
yield response._chunks[-1].usage_metadata.total_token_count
|
rag/llm/cv_model.py
CHANGED
|
@@ -203,6 +203,29 @@ class XinferenceCV(Base):
|
|
| 203 |
)
|
| 204 |
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
class LocalCV(Base):
|
| 208 |
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
|
|
|
|
| 203 |
)
|
| 204 |
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
| 205 |
|
| 206 |
+
class GeminiCV(Base):
|
| 207 |
+
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
|
| 208 |
+
from google.generativeai import client,GenerativeModel
|
| 209 |
+
client.configure(api_key=key)
|
| 210 |
+
_client = client.get_default_generative_client()
|
| 211 |
+
self.model_name = model_name
|
| 212 |
+
self.model = GenerativeModel(model_name=self.model_name)
|
| 213 |
+
self.model._client = _client
|
| 214 |
+
self.lang = lang
|
| 215 |
+
|
| 216 |
+
def describe(self, image, max_tokens=2048):
|
| 217 |
+
from PIL.Image import open
|
| 218 |
+
gen_config = {'max_output_tokens':max_tokens}
|
| 219 |
+
prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
|
| 220 |
+
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
|
| 221 |
+
b64 = self.image2base64(image)
|
| 222 |
+
img = open(BytesIO(base64.b64decode(b64)))
|
| 223 |
+
input = [prompt,img]
|
| 224 |
+
res = self.model.generate_content(
|
| 225 |
+
input,
|
| 226 |
+
generation_config=gen_config,
|
| 227 |
+
)
|
| 228 |
+
return res.text,res.usage_metadata.total_token_count
|
| 229 |
|
| 230 |
class LocalCV(Base):
|
| 231 |
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
|
rag/llm/embedding_model.py
CHANGED
|
@@ -31,7 +31,7 @@ import numpy as np
|
|
| 31 |
import asyncio
|
| 32 |
from api.utils.file_utils import get_home_cache_dir
|
| 33 |
from rag.utils import num_tokens_from_string, truncate
|
| 34 |
-
|
| 35 |
|
| 36 |
class Base(ABC):
|
| 37 |
def __init__(self, key, model_name):
|
|
@@ -419,3 +419,27 @@ class BedrockEmbed(Base):
|
|
| 419 |
|
| 420 |
return np.array(embeddings), token_count
|
| 421 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
import asyncio
|
| 32 |
from api.utils.file_utils import get_home_cache_dir
|
| 33 |
from rag.utils import num_tokens_from_string, truncate
|
| 34 |
+
import google.generativeai as genai
|
| 35 |
|
| 36 |
class Base(ABC):
|
| 37 |
def __init__(self, key, model_name):
|
|
|
|
| 419 |
|
| 420 |
return np.array(embeddings), token_count
|
| 421 |
|
| 422 |
+
class GeminiEmbed(Base):
|
| 423 |
+
def __init__(self, key, model_name='models/text-embedding-004',
|
| 424 |
+
**kwargs):
|
| 425 |
+
genai.configure(api_key=key)
|
| 426 |
+
self.model_name = 'models/' + model_name
|
| 427 |
+
|
| 428 |
+
def encode(self, texts: list, batch_size=32):
|
| 429 |
+
texts = [truncate(t, 2048) for t in texts]
|
| 430 |
+
token_count = sum(num_tokens_from_string(text) for text in texts)
|
| 431 |
+
result = genai.embed_content(
|
| 432 |
+
model=self.model_name,
|
| 433 |
+
content=texts,
|
| 434 |
+
task_type="retrieval_document",
|
| 435 |
+
title="Embedding of list of strings")
|
| 436 |
+
return np.array(result['embedding']),token_count
|
| 437 |
+
|
| 438 |
+
def encode_queries(self, text):
|
| 439 |
+
result = genai.embed_content(
|
| 440 |
+
model=self.model_name,
|
| 441 |
+
content=truncate(text,2048),
|
| 442 |
+
task_type="retrieval_document",
|
| 443 |
+
title="Embedding of single string")
|
| 444 |
+
token_count = num_tokens_from_string(text)
|
| 445 |
+
return np.array(result['embedding']),token_count
|
requirements.txt
CHANGED
|
@@ -147,3 +147,4 @@ markdown==3.6
|
|
| 147 |
mistralai==0.4.2
|
| 148 |
boto3==1.34.140
|
| 149 |
duckduckgo_search==6.1.9
|
|
|
|
|
|
| 147 |
mistralai==0.4.2
|
| 148 |
boto3==1.34.140
|
| 149 |
duckduckgo_search==6.1.9
|
| 150 |
+
google-generativeai==0.7.2
|
requirements_arm.txt
CHANGED
|
@@ -148,3 +148,4 @@ markdown==3.6
|
|
| 148 |
mistralai==0.4.2
|
| 149 |
boto3==1.34.140
|
| 150 |
duckduckgo_search==6.1.9
|
|
|
|
|
|
| 148 |
mistralai==0.4.2
|
| 149 |
boto3==1.34.140
|
| 150 |
duckduckgo_search==6.1.9
|
| 151 |
+
google-generativeai==0.7.2
|
requirements_dev.txt
CHANGED
|
@@ -133,3 +133,4 @@ markdown==3.6
|
|
| 133 |
mistralai==0.4.2
|
| 134 |
boto3==1.34.140
|
| 135 |
duckduckgo_search==6.1.9
|
|
|
|
|
|
| 133 |
mistralai==0.4.2
|
| 134 |
boto3==1.34.140
|
| 135 |
duckduckgo_search==6.1.9
|
| 136 |
+
google-generativeai==0.7.2
|
web/src/assets/svg/llm/gemini.svg
ADDED
|
|
web/src/pages/user-setting/setting-model/index.tsx
CHANGED
|
@@ -61,6 +61,7 @@ const IconMap = {
|
|
| 61 |
Mistral: 'mistral',
|
| 62 |
'Azure-OpenAI': 'azure',
|
| 63 |
Bedrock: 'bedrock',
|
|
|
|
| 64 |
};
|
| 65 |
|
| 66 |
const LlmIcon = ({ name }: { name: string }) => {
|
|
|
|
| 61 |
Mistral: 'mistral',
|
| 62 |
'Azure-OpenAI': 'azure',
|
| 63 |
Bedrock: 'bedrock',
|
| 64 |
+
Gemini:'gemini',
|
| 65 |
};
|
| 66 |
|
| 67 |
const LlmIcon = ({ name }: { name: string }) => {
|