Spaces:
Running
Running
| """ | |
| Infrastructure - Gemini LLM Service | |
| """ | |
| from typing import AsyncIterator, List | |
| import google.generativeai as genai | |
| from app.domain.interfaces import ILLM, LLMMessage, LLMResponse | |
class GeminiLLM(ILLM):
    """Gemini implementation of the domain ``ILLM`` interface.

    Wraps the ``google.generativeai`` SDK so the rest of the application
    talks only to the domain-level ``ILLM`` contract and never imports
    the SDK directly.
    """

    def __init__(self, api_key: str, model_name: str = "gemini-2.0-flash"):
        """Configure the SDK and build the generative model handle.

        Args:
            api_key: Google AI Studio API key.
            model_name: Gemini model identifier (default ``gemini-2.0-flash``).
        """
        # NOTE: genai.configure mutates process-global SDK state, so two
        # GeminiLLM instances with different keys would share the last key set.
        genai.configure(api_key=api_key)
        self.model_name = model_name
        self.model = genai.GenerativeModel(model_name)

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stream: bool = False,
    ) -> LLMResponse:
        """Generate a complete (non-streaming) response from Gemini.

        Args:
            messages: Conversation history in domain message format.
            temperature: Sampling temperature passed to the model.
            max_tokens: Upper bound on generated tokens.
            stream: Accepted for ``ILLM`` interface compatibility only; this
                method always returns a fully materialized response. Use
                :meth:`generate_stream` for incremental output.

        Returns:
            LLMResponse with the generated text, model name, and token usage.
        """
        # Flatten the role-tagged messages into a single text prompt.
        prompt = self._build_prompt(messages)

        response = await self.model.generate_content_async(
            prompt,
            generation_config=genai.types.GenerationConfig(
                temperature=temperature, max_output_tokens=max_tokens
            ),
        )

        # Prefer the SDK's exact accounting when the response carries
        # usage_metadata; fall back to the old whitespace-split approximation
        # for SDK versions/responses that lack it.
        usage = getattr(response, "usage_metadata", None)
        exact_total = getattr(usage, "total_token_count", 0) if usage is not None else 0
        tokens_used = exact_total or (
            len(prompt.split()) + len(response.text.split())
        )

        # NOTE(review): response.text raises ValueError when generation is
        # blocked / returns no candidates — confirm callers handle that.
        return LLMResponse(
            content=response.text,
            model=self.model_name,
            tokens_used=tokens_used,
            finish_reason="stop",
        )

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> AsyncIterator[str]:
        """Yield response text chunks from Gemini as they arrive.

        Args:
            messages: Conversation history in domain message format.
            temperature: Sampling temperature passed to the model.
            max_tokens: Upper bound on generated tokens.

        Yields:
            Non-empty text fragments of the response, in order.
        """
        prompt = self._build_prompt(messages)

        response = await self.model.generate_content_async(
            prompt,
            generation_config=genai.types.GenerationConfig(
                temperature=temperature, max_output_tokens=max_tokens
            ),
            stream=True,
        )

        async for chunk in response:
            # Skip empty keep-alive/metadata chunks that carry no text.
            if chunk.text:
                yield chunk.text

    def get_model_name(self) -> str:
        """Return the configured Gemini model identifier."""
        return self.model_name

    def _build_prompt(self, messages: List[LLMMessage]) -> str:
        """Flatten role-tagged messages into one prompt string.

        Each message is rendered as ``Role: content`` and messages are
        joined with blank lines; roles other than system/user/assistant
        are silently dropped (matches the interface's known roles).
        """
        parts = []
        for msg in messages:
            if msg.role == "system":
                parts.append(f"System: {msg.content}")
            elif msg.role == "user":
                parts.append(f"User: {msg.content}")
            elif msg.role == "assistant":
                parts.append(f"Assistant: {msg.content}")
        return "\n\n".join(parts)