from abc import ABC, abstractmethod
from pathlib import Path

import ollama
from google import genai
from google.genai import types
from groq import Groq
from mistralai import Mistral

from src.manager.utils.streamlit_interface import output_assistant_response

class AbstractModelManager(ABC):
    """Common interface shared by all model managers below."""

    def __init__(self, model_name, system_prompt_file="system.prompt"):
        self.model_name = model_name
        # Resolve the system prompt relative to this file so the managers
        # work regardless of the current working directory.
        script_dir = Path(__file__).parent
        self.system_prompt_file = script_dir / system_prompt_file

    @abstractmethod
    def is_model_loaded(self, model):
        pass

    @abstractmethod
    def create_model(self, base_model, context_window=4096, temperature=0):
        pass

    @abstractmethod
    def request(self, prompt):
        pass

    @abstractmethod
    def delete(self):
        pass

class OllamaModelManager(AbstractModelManager):
    def is_model_loaded(self, model):
        loaded_models = [m.model for m in ollama.list().models]
        return model in loaded_models or f"{model}:latest" in loaded_models

    def create_model(self, base_model, context_window=4096, temperature=0):
        with open(self.system_prompt_file, "r") as f:
            system = f.read()
        if not self.is_model_loaded(self.model_name):
            output_assistant_response(f"Creating model {self.model_name}")
            ollama.create(
                model=self.model_name,
                from_=base_model,
                system=system,
                parameters={
                    "num_ctx": context_window,
                    "temperature": temperature,
                },
            )

    def request(self, prompt):
        response = ollama.chat(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
        )
        return response["message"]["content"]

    def delete(self):
        # Delete the managed model rather than a hard-coded name.
        if self.is_model_loaded(self.model_name):
            output_assistant_response(f"Deleting model {self.model_name}")
            ollama.delete(self.model_name)
        else:
            output_assistant_response(f"Model {self.model_name} not found, skipping deletion.")
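
# Illustrative usage sketch for the Ollama manager (a minimal example, assuming
# a local Ollama daemon with a pulled base model; "C2Rust" and "llama3" below
# are placeholder names, not part of this module):
#
#   manager = OllamaModelManager("C2Rust")
#   manager.create_model(base_model="llama3", context_window=8192)
#   print(manager.request("Translate `int x = 0;` to Rust."))
#   manager.delete()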

class GeminiModelManager(AbstractModelManager):
    def __init__(self, api_key, model_name="gemini-2.0-flash", system_prompt_file="system.prompt"):
        super().__init__(model_name, system_prompt_file)
        self.client = genai.Client(api_key=api_key)
        self.model = model_name
        # Read the system prompt from file.
        with open(self.system_prompt_file, "r") as f:
            self.system_instruction = f.read()

    def is_model_loaded(self, model):
        # Hosted models are referenced by name; report a match against the
        # currently configured model.
        return model == self.model

    def create_model(self, base_model=None, context_window=4096, temperature=0):
        # Hosted Gemini models are not created locally; just switch the
        # configured model name.
        self.model = base_model if base_model else "gemini-2.0-flash"

    def request(self, prompt, temperature=0, context_window=4096):
        # Request a response from the Gemini model. Note that context_window
        # is used here as the output-token cap.
        response = self.client.models.generate_content(
            model=self.model,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=temperature,
                max_output_tokens=context_window,
                system_instruction=self.system_instruction,
            ),
        )
        return response.text

    def delete(self):
        # Hosted models cannot be deleted; just clear the reference.
        self.model = None

class MistralModelManager(AbstractModelManager):
    def __init__(self, api_key, model_name="mistral-small-latest", system_prompt_file="system.prompt"):
        super().__init__(model_name, system_prompt_file)
        self.client = Mistral(api_key=api_key)
        self.model = model_name
        # Read the system prompt from file.
        with open(self.system_prompt_file, "r") as f:
            self.system_instruction = f.read()

    def is_model_loaded(self, model):
        # Hosted models are referenced by name; report a match against the
        # currently configured model.
        return model == self.model

    def create_model(self, base_model=None, context_window=4096, temperature=0):
        # Hosted Mistral models are not created locally; just switch the
        # configured model name.
        self.model = base_model if base_model else "mistral-small-latest"

    def request(self, prompt, temperature=0, context_window=4096):
        # Request a response from the Mistral model.
        response = self.client.chat.complete(
            messages=[
                {
                    "role": "user",
                    "content": self.system_instruction + "\n" + prompt,
                }
            ],
            model=self.model,
            temperature=temperature,
            max_tokens=context_window,
        )
        # The Mistral SDK returns a chat completion object, not raw text.
        return response.choices[0].message.content

    def delete(self):
        # Hosted models cannot be deleted; just clear the reference.
        self.model = None

class GroqModelManager(AbstractModelManager):
    def __init__(self, api_key, model_name="llama-3.3-70b-versatile", system_prompt_file="system.prompt"):
        super().__init__(model_name, system_prompt_file)
        self.client = Groq(api_key=api_key)

    def is_model_loaded(self, model):
        # Groq models are referenced by name; assume available if the name matches.
        return model == self.model_name

    def create_model(self, base_model=None, context_window=4096, temperature=0):
        # Groq does not require explicit model creation; no-op.
        if not self.is_model_loaded(self.model_name):
            output_assistant_response(f"Model {self.model_name} is not available on Groq.")

    def request(self, prompt, temperature=0, context_window=4096):
        # Read the system instruction.
        with open(self.system_prompt_file, "r") as f:
            system_instruction = f.read()
        # Build the messages.
        messages = [
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": prompt},
        ]
        # Send the request.
        response = self.client.chat.completions.create(
            messages=messages,
            model=self.model_name,
            temperature=temperature,
        )
        # Extract and return the content.
        return response.choices[0].message.content

    def delete(self):
        # No deletion support for Groq-managed models.
        output_assistant_response(f"Deletion not supported for Groq model {self.model_name}.")
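
# Minimal end-to-end sketch, assuming a valid GROQ_API_KEY in the environment
# and a "system.prompt" file next to this module; the prompt text below is
# illustrative only.
if __name__ == "__main__":
    import os

    # All managers share the same four-method interface, so any of them can
    # be swapped in here.
    manager = GroqModelManager(api_key=os.environ["GROQ_API_KEY"])
    manager.create_model()
    print(manager.request("Say hello in one sentence."))
    manager.delete()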