|
|
import base64
import io
import mimetypes
import os
import time

import google.generativeai as genai
import gradio as gr
from PIL import Image
|
|
|
|
|
def encode_image(image):
    """Read an image file and return its contents as base64 text.

    Args:
        image: Either a Gradio-style file payload (a dict with a ``'path'``
            key), or the file path itself as a ``str`` or any ``os.PathLike``
            (e.g. ``pathlib.Path``).

    Returns:
        The file's raw bytes, base64-encoded and decoded to a ``str``.

    Raises:
        ValueError: If *image* is none of the accepted shapes.
        OSError: If the file cannot be opened or read.
    """
    if isinstance(image, dict) and 'path' in image:
        image_path = image['path']
    elif isinstance(image, (str, os.PathLike)):
        # open() accepts path-like objects, so there is no reason to reject them.
        image_path = image
    else:
        raise ValueError("Unsupported image format")

    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')
|
|
|
|
|
def _file_path(upload):
    """Return the filesystem path of an upload (a Gradio dict payload or a plain path)."""
    if isinstance(upload, dict) and "path" in upload:
        return upload["path"]
    return upload


def _file_part(upload):
    """Build a Gemini ``inline_data`` part for one uploaded file.

    The MIME type is guessed from the file name and falls back to
    ``"image/jpeg"`` when the extension is unknown.
    """
    # encode_image raises ValueError for unsupported payload shapes.
    data = encode_image(upload)
    mime_type, _ = mimetypes.guess_type(str(_file_path(upload)))
    # NOTE(review): data is base64 text; presumably the SDK decodes it into the
    # proto bytes field (this shape was already used before) — confirm on upgrades.
    return {"inline_data": {"mime_type": mime_type or "image/jpeg", "data": data}}


def _user_parts(entry):
    """Convert the user half of one tuple-format history entry into Gemini parts.

    Text becomes a text part. A tuple is treated as a file upload, ``(path,)``
    or ``(path, caption)``. ``None`` and empty text produce no parts.
    """
    if entry is None:
        return []
    if isinstance(entry, tuple):
        if not entry:
            return []
        parts = [_file_part(entry[0])]
        if len(entry) > 1 and entry[1]:
            parts.append({"text": str(entry[1])})
        return parts
    text = str(entry)
    return [{"text": text}] if text else []


def _add_turn(contents, role, parts):
    """Append *parts* as a *role* turn, merging into the previous turn if its role matches.

    Gradio's tuple history can hold several consecutive user entries (an upload
    and its text are stored separately). The Gemini API expects user and model
    turns to alternate, so same-role neighbours are folded together.
    """
    if not parts:
        return
    if contents and contents[-1]["role"] == role:
        contents[-1]["parts"].extend(parts)
    else:
        contents.append({"role": role, "parts": list(parts)})


def _empty_response_notice(response):
    """Return a short user-facing explanation for a streamed response that produced no text."""
    feedback = getattr(response, "prompt_feedback", None)
    block_reason = getattr(feedback, "block_reason", None)
    if block_reason:
        return f"[The prompt was blocked: {getattr(block_reason, 'name', block_reason)}]"
    candidates = getattr(response, "candidates", None) or []
    finish_reason = getattr(candidates[0], "finish_reason", None) if candidates else None
    if finish_reason:
        return f"[No text was returned. Finish reason: {getattr(finish_reason, 'name', finish_reason)}]"
    return "[No text was returned by the model.]"


def bot_streaming(message, history, api_key, model, system_prompt, temperature, max_tokens, top_p, top_k, harassment, hate_speech, sexually_explicit, dangerous_content):
    """Stream a Gemini reply to *message*, given the prior chat *history*.

    Used as the ``fn`` of a multimodal ``gr.ChatInterface``. Every argument after
    ``history`` comes from its ``additional_inputs``, in the same order.

    Args:
        message: The new user turn. Expected to be a dict with ``"text"`` and
            ``"files"`` keys (multimodal ChatInterface); a plain string is also
            accepted.
        history: Previous turns as ``[user, bot]`` pairs. This assumes Gradio's
            tuple history format — TODO confirm if the Gradio version changes.
            A user half is text, a file tuple ``(path, ...)``, or ``None``.
        api_key: Google AI API key.
        model: Gemini model name, e.g. ``"gemini-1.5-pro"``.
        system_prompt: Optional system instruction; empty means none.
        temperature, max_tokens, top_p, top_k: Generation settings.
        harassment, hate_speech, sexually_explicit, dangerous_content: Names of
            ``genai.types.HarmBlockThreshold`` members, e.g. ``"BLOCK_NONE"``.

    Yields:
        The accumulated reply text after each streamed chunk. If the model
        returned no text at all, yields a short bracketed notice instead.

    Raises:
        gr.Error: If no API key was given or there is nothing to send.
    """
    if not api_key:
        raise gr.Error("Please enter your Google AI API key.")
    # NOTE(review): genai.configure sets module-wide SDK state, so concurrent
    # sessions using different keys may interfere — confirm if the app is shared.
    genai.configure(api_key=api_key)

    # Rebuild the conversation as Gemini contents. The system prompt is not a turn
    # (Gemini turns are only "user"/"model"); it goes to system_instruction below.
    contents = []
    for entry in history or []:
        _add_turn(contents, "user", _user_parts(entry[0]))
        if entry[1]:  # skip None/empty replies instead of sending the text "None"
            _add_turn(contents, "model", [{"text": str(entry[1])}])

    if isinstance(message, dict):
        files = message.get("files") or []
        text = message.get("text") or ""
    else:
        files, text = [], str(message)
    new_parts = [_file_part(upload) for upload in files]
    if text:  # an empty text part would be rejected, e.g. for image-only messages
        new_parts.append({"text": text})
    # If history ended on a user turn whose reply never arrived, this merges
    # into it so that turns keep alternating.
    _add_turn(contents, "user", new_parts)
    if not contents or contents[-1]["role"] != "user":
        raise gr.Error("Please enter a message or upload an image.")

    gemini = genai.GenerativeModel(
        model_name=model,
        # NOTE(review): some older models (e.g. the gemini-1.0 family) may reject
        # system instructions — confirm for the models offered in the UI.
        system_instruction=system_prompt or None,
    )

    chosen_thresholds = {
        "HARM_CATEGORY_HARASSMENT": harassment,
        "HARM_CATEGORY_HATE_SPEECH": hate_speech,
        "HARM_CATEGORY_SEXUALLY_EXPLICIT": sexually_explicit,
        "HARM_CATEGORY_DANGEROUS_CONTENT": dangerous_content,
    }
    safety_settings = [
        {
            "category": getattr(genai.types.HarmCategory, category),
            "threshold": getattr(genai.types.HarmBlockThreshold, threshold),
        }
        for category, threshold in chosen_thresholds.items()
    ]

    # The full conversation (history + new turn) is sent exactly once.
    response = gemini.generate_content(
        contents,
        stream=True,
        generation_config=genai.types.GenerationConfig(
            temperature=temperature,
            # int() guards against the sliders delivering floats for integer fields.
            max_output_tokens=int(max_tokens),
            top_p=top_p,
            top_k=int(top_k),
        ),
        safety_settings=safety_settings,
    )

    buffer = ""
    for chunk in response:
        try:
            piece = chunk.text
        except ValueError:
            # The SDK raises ValueError from `.text` when a chunk carries no text
            # part (e.g. output was blocked); skip such chunks rather than crash.
            continue
        if piece:
            buffer += piece
            yield buffer

    if not buffer:
        yield _empty_response_notice(response)
|
|
|
|
|
|
|
|
# Model names offered in the picker; the selected string is passed to
# bot_streaming as its `model` argument.
MODEL_CHOICES = [
    "gemini-1.5-pro",
    "gemini-1.5-pro-001",
    "gemini-1.5-pro-vision-latest",
    "gemini-1.5-pro-latest",
    "gemini-1.5-flash",
    "gemini-1.5-flash-002",
    "gemini-1.0-pro",
    "gemini-1.0-pro-001",
    "gemini-1.0-pro-vision-latest",
    "gemini-1.0-pro-latest",
]

# Threshold names selectable for each harm category. bot_streaming resolves the
# chosen name with getattr(genai.types.HarmBlockThreshold, name).
SAFETY_THRESHOLDS = ["BLOCK_NONE", "BLOCK_ONLY_HIGH", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_LOW_AND_ABOVE"]
DEFAULT_SAFETY_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE"

INTRO_MARKDOWN = """
# 🤖 Google Gemini API Multimodal Chat

Chat with Google Gemini AI models. Supports text and image interactions.

## 🚀 Quick Start:

1. Enter your Google AI API key
2. Choose a model
3. Start chatting!

Enjoy your AI-powered conversation!
"""

SETTINGS_MARKDOWN = """
## 🔧 Settings:

- Adjust temperature, max tokens, Top P and Top K in the additional inputs panel below the chat
- Fine-tune the safety thresholds for each harm category in the same panel
- Upload images for multimodal interactions
"""


def _safety_dropdown(label):
    """Return a dropdown for choosing the block threshold of one harm category."""
    return gr.Dropdown(label=label, choices=SAFETY_THRESHOLDS, value=DEFAULT_SAFETY_THRESHOLD)


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Markdown needs a space after "#" for the line to render as a heading.
    gr.Markdown("# 💬 Chat with Google Gemini AI")
    gr.Markdown("### Upload images or type your message to start the conversation.")

    api_key = gr.Textbox(label="API Key", type="password", placeholder="Enter your Google AI API key")
    model = gr.Dropdown(label="Select Model", choices=MODEL_CHOICES, value="gemini-1.5-pro")
    system_prompt = gr.Textbox(label="System Prompt", placeholder="Enter a system prompt (optional)")

    # additional_inputs order must match bot_streaming's parameters after
    # (message, history).
    # NOTE(review): retry_btn/undo_btn/clear_btn are Gradio 4 ChatInterface
    # arguments; they appear to be gone in Gradio 5 — pin gradio<5 or drop them.
    chatbot = gr.ChatInterface(
        fn=bot_streaming,
        additional_inputs=[
            api_key,
            model,
            system_prompt,
            gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(minimum=1, maximum=2048, value=1000, step=1, label="Max Tokens"),
            gr.Slider(minimum=0, maximum=1, value=0.95, step=0.01, label="Top P"),
            gr.Slider(minimum=1, maximum=40, value=40, step=1, label="Top K"),
            _safety_dropdown("Harassment"),
            _safety_dropdown("Hate Speech"),
            _safety_dropdown("Sexually Explicit"),
            _safety_dropdown("Dangerous Content"),
        ],
        retry_btn="🔄 Retry",
        undo_btn="↩️ Undo",
        clear_btn="🗑️ Clear",
        multimodal=True,
        cache_examples=False,
        fill_height=True,
    )

    gr.Markdown(INTRO_MARKDOWN)
    gr.Markdown(SETTINGS_MARKDOWN)

demo.launch(debug=True, share=True)
|
|
|