chips commited on
Commit
4645b0f
·
1 Parent(s): 0e18f46

gemini added as service

Browse files
Files changed (1) hide show
  1. app/services/factory.py +31 -5
app/services/factory.py CHANGED
@@ -2,19 +2,45 @@ from typing import Type
2
 
3
  from ..config import get_settings
4
  from .base import BaseAttributionService
5
- from .service_anthropic import AnthropicService
6
  from .service_openai import OpenAIService
 
7
 
8
  settings = get_settings()
9
 
10
 
11
  class AIServiceFactory:
12
- _services = {"openai": OpenAIService, "anthropic": AnthropicService}
 
 
 
 
13
 
14
  @classmethod
15
  def get_service(cls, ai_vendor: str = None) -> BaseAttributionService:
16
- ai_vendor = ai_vendor or settings.DEFAULT_VENDOR
17
- service_class = cls._services.get(ai_vendor.lower())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  if not service_class:
19
- raise ValueError(f"Unsupported ai_vendor: {ai_vendor}")
 
 
20
  return service_class()
 
 
2
 
3
  from ..config import get_settings
4
  from .base import BaseAttributionService
5
+ from .service_anthropic import AnthropicService # Assuming this exists
6
  from .service_openai import OpenAIService
7
+ from .service_gemini import GeminiService # Import the new GeminiService
8
 
9
  settings = get_settings()
10
 
11
 
12
  class AIServiceFactory:
13
+ _services = {
14
+ "openai": OpenAIService,
15
+ "anthropic": AnthropicService,
16
+ "gemini": GeminiService, # Add Gemini service
17
+ }
18
 
19
  @classmethod
20
  def get_service(cls, ai_vendor: str = None) -> BaseAttributionService:
21
+ """
22
+ Factory method to get an AI service instance.
23
+
24
+ Args:
25
+ ai_vendor (str, optional): The name of the AI vendor.
26
+ Defaults to settings.DEFAULT_VENDOR.
27
+
28
+ Returns:
29
+ BaseAttributionService: An instance of the requested AI service.
30
+
31
+ Raises:
32
+ ValueError: If the ai_vendor is unsupported.
33
+ """
34
+ # Use the provided ai_vendor or fallback to the default from settings
35
+ ai_vendor_to_use = ai_vendor or settings.DEFAULT_VENDOR
36
+
37
+ # Retrieve the service class from the dictionary (case-insensitive lookup)
38
+ service_class = cls._services.get(ai_vendor_to_use.lower())
39
+
40
+ # If the service class is not found, raise an error
41
  if not service_class:
42
+ raise ValueError(f"Unsupported AI vendor: {ai_vendor_to_use}. Supported vendors are: {', '.join(cls._services.keys())}")
43
+
44
+ # Create and return an instance of the service
45
  return service_class()
46
+