File size: 5,097 Bytes
eba86cf
 
 
e08f161
eba86cf
 
 
2b26ed4
eba86cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6cb1b4
eba86cf
 
 
 
 
 
 
 
 
a7d0aad
eba86cf
 
 
 
a7d0aad
eba86cf
 
 
89321e2
 
 
 
2b26ed4
a7d0aad
2b26ed4
89321e2
 
2b26ed4
89321e2
2b26ed4
 
 
 
eba86cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99f2cbc
 
 
 
 
 
 
 
 
89321e2
99f2cbc
 
 
 
 
 
 
89321e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eba86cf
 
 
 
 
 
 
 
 
148dc3c
 
 
 
 
 
89321e2
 
 
eba86cf
 
89321e2
eba86cf
 
89321e2
eba86cf
89321e2
 
 
 
 
 
 
 
 
 
 
 
 
 
eba86cf
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Base class for LLM providers."""

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
import base64
from pathlib import Path
from medrax.utils.utils import load_prompts_from_file


@dataclass
class LLMRequest:
    """A single prompt submitted to an LLM provider.

    Attributes:
        text: Prompt text sent to the model.
        images: Optional list of image file paths to attach; ``None`` means
            the request is text-only.
    """

    # Prompt text for the model.
    text: str
    # Paths of images to attach (no images by default).
    images: Optional[List[str]] = None


@dataclass
class LLMResponse:
    """Result produced by an LLM provider for one request.

    Attributes:
        content: Text returned by the model.
        usage: Optional string-keyed usage metadata (presumably token counts;
            the exact keys are set by each provider implementation).
        duration: Optional elapsed time of the call (units chosen by the
            provider implementation — presumably seconds, confirm).
        chunk_history: Optional extra payload of unconstrained type
            (presumably the streamed chunks, when a provider records them).
    """

    # Generated text.
    content: str
    # Usage metadata, if the provider reports any.
    usage: Optional[Dict[str, Any]] = None
    # Call duration, if measured.
    duration: Optional[float] = None
    # Provider-specific chunk record, if kept.
    chunk_history: Optional[Any] = None
    

class LLMProvider(ABC):
    """Abstract base class for LLM providers.
    
    This class defines the interface for all LLM providers, standardizing
    text + image input -> text output across different models and APIs.
    """

    def __init__(self, model_name: str, system_prompt: str, **kwargs):
        """Initialize the LLM provider.
        
        Args:
            model_name (str): Name of the model to use
            system_prompt (str): System prompt identifier to load from file
            **kwargs: Additional configuration parameters
        """
        self.model_name = model_name
        self.temperature = kwargs.get("temperature", 0.7)
        self.top_p = kwargs.get("top_p", 0.95)
        self.max_tokens = kwargs.get("max_tokens", 5000)
        self.prompt_name = system_prompt
        
        # Load system prompt content from file
        try:
            prompts = load_prompts_from_file("benchmarking/system_prompts.txt")
            self.system_prompt = prompts.get(self.prompt_name, None)
            if self.system_prompt is None:
                print(f"Warning: System prompt '{system_prompt}' not found in benchmarking/system_prompts.txt.")
        except Exception as e:
            print(f"Error loading system prompt: {e}")
            self.system_prompt = None

        self._setup()

    @abstractmethod
    def _setup(self) -> None:
        """Set up the provider (API keys, client initialization, etc.)."""
        pass

    @abstractmethod
    def generate_response(self, request: LLMRequest) -> LLMResponse:
        """Generate a response from the LLM.
        
        Args:
            request (LLMRequest): The request containing text, images, and parameters
            
        Returns:
            LLMResponse: The response from the LLM
        """
        pass

    def test_connection(self) -> bool:
        """Test the connection to the LLM provider.
        
        Returns:
            bool: True if connection is successful, False otherwise
        """
        try:
            # Simple test request
            test_request = LLMRequest(
                text="Hello! What model are you? Tell me your full specification."
            )
            response = self.generate_response(test_request)
            return response.content is not None and len(response.content.strip()) > 0
        except Exception as e:
            print(f"Connection test failed: {e}")
            return False

    def _validate_image_paths(self, image_paths: List[str]) -> List[str]:
        """Validate that image paths exist and are readable.
        
        Args:
            image_paths (List[str]): List of image paths to validate
            
        Returns:
            List[str]: List of valid image paths
        """
        valid_paths = []
        for path in image_paths:
            if Path(path).exists() and Path(path).is_file():
                valid_paths.append(path)
            else:
                print(f"Warning: Image path does not exist: {path}")
        return valid_paths

    def _encode_image(self, image_path: str) -> str:
        """Encode image to base64 string.
        
        Args:
            image_path (str): Path to the image file
            
        Returns:
            str: Base64 encoded image string
        """
        try:
            with open(image_path, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode('utf-8')
        except Exception as e:
            print(f"ERROR: _encode_image failed for {image_path} (type: {type(image_path)}): {e}")
            raise
    
    def _get_image_mime_type(self, image_path: str) -> str:
        """Detect the MIME type of an image file.
        
        Args:
            image_path (str): Path to the image file
            
        Returns:
            str: MIME type (e.g., 'image/png', 'image/jpeg')
        """
        # Get file extension
        ext = Path(image_path).suffix.lower()
        
        # Map extensions to MIME types
        mime_types = {
            '.png': 'image/png',
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.gif': 'image/gif',
            '.webp': 'image/webp',
            '.bmp': 'image/bmp',
        }
        
        return mime_types.get(ext, 'image/png')  # Default to PNG for medical images

    def __str__(self) -> str:
        """String representation of the provider."""
        return f"{self.__class__.__name__}(model={self.model_name})"