Spaces:
Running
Running
Kunal Pai
committed on
Commit
·
ffe6e74
1
Parent(s):
2526988
Implement model managers for Ollama, Gemini, and Mistral; update requirements.txt with new dependencies
Browse files- models/llm_models.py +137 -0
- requirements.txt +22 -1
models/llm_models.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
import ollama
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from google import genai
|
| 6 |
+
from google.genai import types
|
| 7 |
+
from mistralai import Mistral
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AbstractModelManager(ABC):
    """Abstract base class defining the interface for LLM backend managers.

    Concrete subclasses wrap a specific provider (Ollama, Gemini, Mistral)
    and must implement a loaded-check, model creation, a single-prompt
    request, and deletion.
    """

    def __init__(self, model_name, system_prompt_file="system.prompt"):
        # Name of the model this manager controls.
        self.model_name = model_name
        # Resolve the prompt file relative to this module's directory so the
        # manager works regardless of the process's current working directory.
        script_dir = Path(__file__).parent
        self.system_prompt_file = script_dir / system_prompt_file

    @abstractmethod
    def is_model_loaded(self, model):
        """Return True if `model` is available/loaded in the backend."""
        pass

    @abstractmethod
    def create_model(self, base_model, context_window=4096, temperature=0):
        """Create/configure the managed model from `base_model`."""
        pass

    @abstractmethod
    def request(self, prompt):
        """Send `prompt` to the model and return the response text."""
        pass

    @abstractmethod
    def delete(self):
        """Remove/unload the managed model from the backend."""
        pass
|
| 31 |
+
|
| 32 |
+
class OllamaModelManager(AbstractModelManager):
    """Manage a local Ollama model: creation, chat requests, and deletion."""

    def is_model_loaded(self, model):
        """Return True if `model` (bare or with a ':latest' tag) exists locally."""
        loaded_models = [m.model for m in ollama.list().models]
        return model in loaded_models or f'{model}:latest' in loaded_models

    def create_model(self, base_model, context_window=4096, temperature=0):
        """Create the managed model from `base_model` with the file's system prompt.

        No-op if a model named `self.model_name` is already present.

        Args:
            base_model: Name of the existing Ollama model to derive from.
            context_window: Value for the `num_ctx` parameter.
            temperature: Sampling temperature (0 = deterministic).
        """
        with open(self.system_prompt_file, 'r') as f:
            system = f.read()

        if not self.is_model_loaded(self.model_name):
            print(f"Creating model {self.model_name}")
            ollama.create(
                model=self.model_name,
                from_=base_model,
                system=system,
                parameters={
                    "num_ctx": context_window,
                    "temperature": temperature
                }
            )

    def request(self, prompt):
        """Send a single-turn chat request and return the response text."""
        response = ollama.chat(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
        )
        response = response['message']['content']
        return response

    def delete(self):
        """Delete the managed model if it is present.

        Bug fix: previously checked for and deleted the hard-coded
        "C2Rust:latest" model regardless of `self.model_name`, so any other
        managed model was never deleted.
        """
        if self.is_model_loaded(self.model_name):
            print(f"Deleting model {self.model_name}")
            ollama.delete(self.model_name)
        else:
            print(f"Model {self.model_name} not found, skipping deletion.")
|
| 67 |
+
|
| 68 |
+
class GeminiModelManager(AbstractModelManager):
    """Manage requests to Google's Gemini models via the google-genai client."""

    def __init__(self, api_key, model_name="gemini-2.0-flash", system_prompt_file="system.prompt"):
        """Create a Gemini-backed manager.

        Args:
            api_key: Google Gen AI API key.
            model_name: Gemini model identifier (defaults to the previous
                hard-coded "gemini-2.0-flash", so existing callers are unchanged).
            system_prompt_file: Prompt file name, resolved by the base class
                relative to this module's directory.

        Bug fix: the base-class initializer requires `model_name`; the old
        no-argument `super().__init__()` call raised TypeError and left
        `self.system_prompt_file` unset before it was read below.
        """
        super().__init__(model_name, system_prompt_file)
        self.client = genai.Client(api_key=api_key)
        self.model = model_name
        # Read the system prompt once at construction time.
        with open(self.system_prompt_file, 'r') as f:
            self.system_instruction = f.read()

    def is_model_loaded(self, model):
        """Return True if `model` is the one currently set in the manager."""
        return model == self.model

    def create_model(self, base_model=None, context_window=4096, temperature=0):
        """Select the Gemini model to use (remote models need no creation step)."""
        self.model = base_model if base_model else "gemini-2.0-flash"

    def request(self, prompt, temperature=0, context_window=4096):
        """Send `prompt` to the Gemini model and return the response text.

        Args:
            prompt: User prompt text.
            temperature: Sampling temperature.
            context_window: Used as `max_output_tokens` for the response.
        """
        response = self.client.models.generate_content(
            model=self.model,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=temperature,
                max_output_tokens=context_window,
                system_instruction=self.system_instruction,
            )
        )
        return response.text

    def delete(self):
        """Clear the selected model (remote models cannot be deleted here)."""
        self.model = None
|
| 102 |
+
|
| 103 |
+
class MistralModelManager(AbstractModelManager):
    """Manage requests to Mistral models via the mistralai client."""

    def __init__(self, api_key, model_name="mistral-small-latest", system_prompt_file="system.prompt"):
        """Create a Mistral-backed manager.

        Args:
            api_key: Mistral API key.
            model_name: Mistral model identifier.
            system_prompt_file: Prompt file name, resolved by the base class
                relative to this module's directory.

        Bug fix: the base-class initializer requires `model_name`; the old
        no-argument `super().__init__()` call raised TypeError, ignored the
        `system_prompt_file` parameter, and left `self.system_prompt_file`
        unset before it was read below.
        """
        super().__init__(model_name, system_prompt_file)
        self.client = Mistral(api_key=api_key)
        self.model = model_name
        # Read the system prompt once at construction time.
        with open(self.system_prompt_file, 'r') as f:
            self.system_instruction = f.read()

    def is_model_loaded(self, model):
        """Return True if `model` is the one currently set in the manager."""
        return model == self.model

    def create_model(self, base_model=None, context_window=4096, temperature=0):
        """Select the Mistral model to use (remote models need no creation step)."""
        self.model = base_model if base_model else "mistral-small-latest"

    def request(self, prompt, temperature=0, context_window=4096):
        """Send `prompt` (prefixed by the system prompt) and return the text.

        Args:
            prompt: User prompt text.
            temperature: Sampling temperature.
            context_window: Used as `max_tokens` for the response.
        """
        response = self.client.chat.complete(
            messages=[
                {
                    "role": "user",
                    "content": self.system_instruction + "\n" + prompt,
                }
            ],
            model=self.model,
            temperature=temperature,
            max_tokens=context_window,
        )
        # Bug fix: ChatCompletionResponse has no `.text` attribute; the
        # generated text lives at choices[0].message.content.
        return response.choices[0].message.content

    def delete(self):
        """Clear the selected model (remote models cannot be deleted here)."""
        self.model = None
|
requirements.txt
CHANGED
|
@@ -1,19 +1,40 @@
|
|
| 1 |
annotated-types==0.7.0
|
| 2 |
anyio==4.9.0
|
| 3 |
beautifulsoup4==4.13.3
|
|
|
|
| 4 |
certifi==2025.1.31
|
| 5 |
charset-normalizer==3.4.1
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
h11==0.14.0
|
| 8 |
httpcore==1.0.7
|
|
|
|
| 9 |
httpx==0.28.1
|
| 10 |
idna==3.10
|
| 11 |
ollama==0.4.7
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
pydantic==2.11.1
|
| 13 |
pydantic_core==2.33.0
|
|
|
|
|
|
|
| 14 |
requests==2.32.3
|
|
|
|
| 15 |
sniffio==1.3.1
|
| 16 |
soupsieve==2.6
|
|
|
|
| 17 |
typing-inspection==0.4.0
|
| 18 |
typing_extensions==4.13.0
|
|
|
|
| 19 |
urllib3==2.3.0
|
|
|
|
|
|
| 1 |
annotated-types==0.7.0
|
| 2 |
anyio==4.9.0
|
| 3 |
beautifulsoup4==4.13.3
|
| 4 |
+
cachetools==5.5.2
|
| 5 |
certifi==2025.1.31
|
| 6 |
charset-normalizer==3.4.1
|
| 7 |
+
google-ai-generativelanguage==0.6.15
|
| 8 |
+
google-api-core==2.24.2
|
| 9 |
+
google-api-python-client==2.166.0
|
| 10 |
+
google-auth==2.38.0
|
| 11 |
+
google-auth-httplib2==0.2.0
|
| 12 |
+
google-genai==1.9.0
|
| 13 |
+
googleapis-common-protos==1.69.2
|
| 14 |
+
grpcio==1.71.0
|
| 15 |
+
grpcio-status==1.71.0
|
| 16 |
h11==0.14.0
|
| 17 |
httpcore==1.0.7
|
| 18 |
+
httplib2==0.22.0
|
| 19 |
httpx==0.28.1
|
| 20 |
idna==3.10
|
| 21 |
ollama==0.4.7
|
| 22 |
+
pathlib==1.0.1
|
| 23 |
+
proto-plus==1.26.1
|
| 24 |
+
protobuf==5.29.4
|
| 25 |
+
pyasn1==0.6.1
|
| 26 |
+
pyasn1_modules==0.4.2
|
| 27 |
pydantic==2.11.1
|
| 28 |
pydantic_core==2.33.0
|
| 29 |
+
pyparsing==3.2.3
|
| 30 |
+
python-dotenv==1.1.0
|
| 31 |
requests==2.32.3
|
| 32 |
+
rsa==4.9
|
| 33 |
sniffio==1.3.1
|
| 34 |
soupsieve==2.6
|
| 35 |
+
tqdm==4.67.1
|
| 36 |
typing-inspection==0.4.0
|
| 37 |
typing_extensions==4.13.0
|
| 38 |
+
uritemplate==4.1.1
|
| 39 |
urllib3==2.3.0
|
| 40 |
+
websockets==15.0.1
|