File size: 3,230 Bytes
a1660ff
 
d821914
a1660ff
 
d821914
 
 
 
 
 
 
 
 
a1660ff
 
d821914
5b467a0
a1660ff
5b467a0
 
 
 
 
 
a1660ff
5b467a0
a1660ff
 
 
 
 
 
5b467a0
a1660ff
 
 
 
 
 
 
 
 
 
 
5b467a0
 
a1660ff
5b467a0
 
 
 
 
 
 
a1660ff
 
 
 
5b467a0
 
 
 
 
a1660ff
 
 
 
 
 
5b467a0
 
a1660ff
5b467a0
a1660ff
 
 
 
 
 
 
 
 
 
5b467a0
a1660ff
5b467a0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from src.config.settings import GEMINI_API_KEY, CHATBOT_NAME, MODEL_ID, MODEL_TEMPERATURE, MODEL_OPTIONS, IMAGE_WIDTH, IMAGE_HEIGHT, SYSTEM_INSTRUCTION
from typing import Dict, List, Optional
from PIL import Image
from google import genai
from google.genai import types
import time

def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
    """Parse a comma-separated string of stop sequences.

    Each entry is whitespace-stripped. Empty entries (from input such as
    "a,,b", a trailing comma, or a blank string) are dropped, because an empty
    stop sequence is meaningless and could be rejected by the API.

    Args:
        stop_sequences: Comma-separated stop sequences; may be empty or None.

    Returns:
        The non-empty stop sequences in input order, or None when there are none.
    """
    if not stop_sequences:
        return None
    stripped = [seq.strip() for seq in stop_sequences.split(",")]
    return [seq for seq in stripped if seq] or None

def preprocess_image(image: Image.Image) -> Image.Image:
    """Resize an image to IMAGE_WIDTH pixels wide, preserving its aspect ratio.

    The height is scaled proportionally and clamped to at least 1 pixel.
    For very wide images the proportional height truncates to 0, and Pillow's
    ``resize`` rejects a zero-sized target.

    Args:
        image: The image to resize.

    Returns:
        A new, resized image; the input image is not modified.
    """
    scaled_height = max(1, int(image.height * IMAGE_WIDTH / image.width))
    return image.resize((IMAGE_WIDTH, scaled_height))

def user(text_prompt: str, chatbot: List[Dict[str, str]]):
    """Append the user's message to the chat history.

    The caller's list is not modified; a new list is built instead.

    Returns:
        A pair ``("", history)``. The empty string is presumably used to clear
        the input textbox. ``history`` is a new list holding the old turns plus
        ``{"role": "user", "content": text_prompt}``.
    """
    new_turn = {"role": "user", "content": text_prompt}
    updated_history = chatbot + [new_turn]
    cleared_input = ""
    return cleared_input, updated_history

# Streamed text is revealed this many characters at a time, with a short pause
# between slices. This is an artificial "typing" effect, so long replies take
# roughly len(text) / 10 * 0.01 seconds extra to display.
_TYPING_SLICE_CHARS = 10
_TYPING_DELAY_SECONDS = 0.01


def _extract_text_prompt(raw_content: object) -> Optional[str]:
    """Return the stripped text of a chat message's content, or None if it has none."""
    # Gradio v6 may store content as a list of parts or a plain string
    if isinstance(raw_content, str):
        return raw_content.strip() or None
    if isinstance(raw_content, list):
        texts = (
            part.get("text", "") if isinstance(part, dict) else str(part)
            for part in raw_content
        )
        # Skip empty/None parts (e.g. non-text parts without a "text" key) so
        # they neither break the join nor leave doubled spaces in the prompt.
        return " ".join(text for text in texts if text).strip() or None
    return None


def _build_generation_config(
    temperature: float,
    max_output_tokens: int,
    stop_sequences: str,
    top_k: int,
    top_p: float,
) -> "types.GenerateContentConfig":
    """Assemble the Gemini request config from the UI's sampling controls."""
    return types.GenerateContentConfig(
        system_instruction=SYSTEM_INSTRUCTION,
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        stop_sequences=preprocess_stop_sequences(stop_sequences),
        top_k=top_k,
        top_p=top_p,
        safety_settings=[
            types.SafetySetting(
                category="HARM_CATEGORY_DANGEROUS_CONTENT",
                threshold="BLOCK_ONLY_HIGH",
            )
        ],
    )


def bot(
    model_name: str,
    image_prompt: Optional[Image.Image],
    temperature: float,
    max_output_tokens: int,
    stop_sequences: str,
    top_k: int,
    top_p: float,
    chatbot: List[Dict[str, str]]
):
    """Stream a Gemini reply to the latest chat message into ``chatbot``.

    This is a generator. It mutates ``chatbot`` in place, appending one
    assistant turn, and yields the same list after every update so the UI can
    re-render progressively.

    Args:
        model_name: Gemini model id. Falls back to ``MODEL_ID`` when empty.
        image_prompt: Optional image to send with the prompt. It is resized
            with ``preprocess_image`` first.
        temperature, max_output_tokens, top_k, top_p: Sampling controls,
            passed to ``types.GenerateContentConfig``.
        stop_sequences: Comma-separated stop sequences, parsed by
            ``preprocess_stop_sequences``.
        chatbot: Chat history as role/content dicts. The last entry is
            expected to be the user's message (presumably appended by
            ``user``); its text becomes the prompt.

    Yields:
        ``chatbot`` after each change. Missing API key, empty input, API
        errors and empty model output are reported as assistant messages
        rather than raised.
    """
    if not GEMINI_API_KEY:
        chatbot.append({"role": "assistant", "content": "GEMINI_API_KEY is not set. Please add it to your .env file."})
        yield chatbot
        return

    text_prompt = _extract_text_prompt(chatbot[-1].get("content") if chatbot else None)

    if not text_prompt and not image_prompt:
        chatbot.append({"role": "assistant", "content": "Prompt cannot be empty. Please provide input text or an image."})
        yield chatbot
        return
    if image_prompt and not text_prompt:
        text_prompt = "Describe the image"
    elif image_prompt:
        # Avoid "?. Also" / ".. Also" when the user's text already ends a sentence.
        separator = " " if text_prompt.endswith((".", "!", "?")) else ". "
        text_prompt = f"{text_prompt}{separator}Also, analyze the provided image."

    reply = {"role": "assistant", "content": ""}
    chatbot.append(reply)
    try:
        # Everything that can fail runs here, so its errors are shown in the
        # chat instead of crashing the handler.
        contents = [text_prompt] if image_prompt is None else [text_prompt, preprocess_image(image_prompt)]
        client = genai.Client(api_key=GEMINI_API_KEY)
        config = _build_generation_config(temperature, max_output_tokens, stop_sequences, top_k, top_p)
        for chunk in client.models.generate_content_stream(
            model=model_name or MODEL_ID,
            contents=contents,
            config=config,
        ):
            text = chunk.text  # read once rather than re-evaluating the property per slice
            if not text:
                continue
            for start in range(0, len(text), _TYPING_SLICE_CHARS):
                reply["content"] += text[start:start + _TYPING_SLICE_CHARS]
                time.sleep(_TYPING_DELAY_SECONDS)
                yield chatbot
    except Exception as e:  # UI boundary: surface any failure to the user
        error_message = f"Error occurred: {e}"
        # Keep any text already streamed so a mid-stream failure does not erase it.
        reply["content"] = f"{reply['content']}\n\n{error_message}" if reply["content"] else error_message
        yield chatbot
    else:
        if not reply["content"]:
            # No chunk carried text (the response may have been blocked or cut
            # off), so nothing was yielded yet. Tell the user instead of ending
            # silently with an empty assistant turn.
            reply["content"] = "No response was generated by the model."
            yield chatbot