Spaces:
Paused
Paused
| from collections.abc import Generator | |
| import google.generativeai.types.generation_types as generation_config_types | |
| import pytest | |
| from _pytest.monkeypatch import MonkeyPatch | |
| from google.ai import generativelanguage as glm | |
| from google.ai.generativelanguage_v1beta.types import content as gag_content | |
| from google.generativeai import GenerativeModel | |
| from google.generativeai.client import _ClientManager, configure | |
| from google.generativeai.types import GenerateContentResponse, content_types, safety_types | |
| from google.generativeai.types.generation_types import BaseGenerateContentResponse | |
| current_api_key = "" | |
| class MockGoogleResponseClass: | |
| _done = False | |
| def __iter__(self): | |
| full_response_text = "it's google!" | |
| for i in range(0, len(full_response_text) + 1, 1): | |
| if i == len(full_response_text): | |
| self._done = True | |
| yield GenerateContentResponse( | |
| done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[] | |
| ) | |
| else: | |
| yield GenerateContentResponse( | |
| done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[] | |
| ) | |
| class MockGoogleResponseCandidateClass: | |
| finish_reason = "stop" | |
| def content(self) -> gag_content.Content: | |
| return gag_content.Content(parts=[gag_content.Part(text="it's google!")]) | |
| class MockGoogleClass: | |
| def generate_content_sync() -> GenerateContentResponse: | |
| return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]) | |
| def generate_content_stream() -> Generator[GenerateContentResponse, None, None]: | |
| return MockGoogleResponseClass() | |
| def generate_content( | |
| self: GenerativeModel, | |
| contents: content_types.ContentsType, | |
| *, | |
| generation_config: generation_config_types.GenerationConfigType | None = None, | |
| safety_settings: safety_types.SafetySettingOptions | None = None, | |
| stream: bool = False, | |
| **kwargs, | |
| ) -> GenerateContentResponse: | |
| global current_api_key | |
| if len(current_api_key) < 16: | |
| raise Exception("Invalid API key") | |
| if stream: | |
| return MockGoogleClass.generate_content_stream() | |
| return MockGoogleClass.generate_content_sync() | |
| def generative_response_text(self) -> str: | |
| return "it's google!" | |
| def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]: | |
| return [MockGoogleResponseCandidateClass()] | |
| def make_client(self: _ClientManager, name: str): | |
| global current_api_key | |
| if name.endswith("_async"): | |
| name = name.split("_")[0] | |
| cls = getattr(glm, name.title() + "ServiceAsyncClient") | |
| else: | |
| cls = getattr(glm, name.title() + "ServiceClient") | |
| # Attempt to configure using defaults. | |
| if not self.client_config: | |
| configure() | |
| client_options = self.client_config.get("client_options", None) | |
| if client_options: | |
| current_api_key = client_options.api_key | |
| def nop(self, *args, **kwargs): | |
| pass | |
| original_init = cls.__init__ | |
| cls.__init__ = nop | |
| client: glm.GenerativeServiceClient = cls(**self.client_config) | |
| cls.__init__ = original_init | |
| if not self.default_metadata: | |
| return client | |
| def setup_google_mock(request, monkeypatch: MonkeyPatch): | |
| monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text) | |
| monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates) | |
| monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content) | |
| monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client) | |
| yield | |
| monkeypatch.undo() | |