# llm_clients/lmstudio.py
from typing import Generator, Any, Dict
import requests
import json
from .base import LlmClient

class LmstudioClient(LlmClient):
    """LLM client for LM Studio models (OpenAI-compatible API)."""

    def __init__(self, config_dict: Dict[str, Any], system_prompt: str):
        super().__init__(config_dict, system_prompt)
        # LM Studio exposes an OpenAI-compatible HTTP endpoint
        self.base_url = self.config.get('host', 'http://localhost:1234')
        
        # Test connection to LM Studio
        self._test_connection()
        
        print(f"βœ… LM Studio Client initialized for model '{self.config['model']}' at host '{self.base_url}'.")
        print(f"   Note: LM Studio uses just-in-time loading - model will load on first request.")

    def _test_connection(self):
        """Test connection to LM Studio server."""
        try:
            # Try the models endpoint first (more reliable than health)
            response = requests.get(f"{self.base_url}/v1/models", timeout=5)
            response.raise_for_status()
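            # The /v1/models reply follows the OpenAI list shape, roughly
            # (a sketch, not verbatim server output):
            #   {"object": "list", "data": [{"id": "some-model", "object": "model"}, ...]}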
            
            # Check if our specific model is available
            try:
                models_data = response.json()
                available_models = [model.get('id', '') for model in models_data.get('data', [])]
                
                if available_models:
                    print(f"   πŸ“‹ Available models in LM Studio: {', '.join(available_models)}")
                    if self.config['model'] not in available_models:
                        print(f"   ⚠️  Warning: Model '{self.config['model']}' not found in available models.")
                        print(f"       This is normal with just-in-time loading - model will load on first use.")
                else:
                    print("   πŸ“‹ LM Studio is running with just-in-time model loading.")
                    
            except (json.JSONDecodeError, KeyError):
                print("   πŸ“‹ LM Studio is running (could not parse models list).")
                
        except requests.exceptions.RequestException as e:
            raise ConnectionError(
                f"Could not connect to LM Studio at {self.base_url}. "
                f"Error: {e}\n"
                f"Please ensure:\n"
                f"1. LM Studio is running\n"
                f"2. A model is loaded or just-in-time loading is enabled\n"
                f"3. The server is started (look for 'Server started' in LM Studio console)\n"
                f"4. The correct host/port is configured (default: http://localhost:1234)"
            ) from e

    def generate_content(self, prompt: str) -> str:
        """
        Generates a non-streaming response from LM Studio.
        Uses OpenAI-compatible API format.
        """
        url = f"{self.base_url}/v1/chat/completions"
        
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt}
        ]
        
        payload = {
            "model": self.config['model'],
            "messages": messages,
            "stream": False,
            "temperature": self.config.get('temperature', 0.1),  # Low temperature for security scanning
            "max_tokens": self.config.get('max_tokens', 500)
        }
        
        try:
            response = requests.post(url, json=payload, timeout=30)
            response.raise_for_status()
            
            result = response.json()
            if 'choices' in result and len(result['choices']) > 0:
                return result['choices'][0]['message']['content']
            else:
                raise ValueError(f"Unexpected response format from LM Studio: {result}")
                
        except requests.exceptions.RequestException as e:
            # Inspect the HTTP status code directly; matching "404" in the
            # stringified exception is fragile.
            status = e.response.status_code if e.response is not None else None
            if status == 404:
                raise ConnectionError(
                    f"LM Studio endpoint not found. Please ensure:\n"
                    f"1. LM Studio server is running\n"
                    f"2. A model is loaded (or just-in-time loading is enabled)\n"
                    f"3. The model name '{self.config['model']}' is correct"
                ) from e
            raise ConnectionError(f"Error communicating with LM Studio: {e}") from e
        except (json.JSONDecodeError, KeyError, ValueError) as e:
            raise ValueError(f"Error parsing LM Studio response: {e}") from e

    def generate_content_stream(self, prompt: str) -> Generator[str, None, None]:
        """
        Generates a streaming response from LM Studio.
        Uses OpenAI-compatible API format.
        """
        url = f"{self.base_url}/v1/chat/completions"
        
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt}
        ]
        
        payload = {
            "model": self.config['model'],
            "messages": messages,
            "stream": True,
            "temperature": self.config.get('temperature', 0.7),
            "max_tokens": self.config.get('max_tokens', 2000)
        }
        
        try:
            with requests.post(url, json=payload, stream=True, timeout=30) as response:
                response.raise_for_status()
                
                for line in response.iter_lines():
                    if line:
                        line_str = line.decode('utf-8')
                        if line_str.startswith('data: '):
                            line_str = line_str[6:]  # Remove 'data: ' prefix
                            
                        if line_str.strip() == '[DONE]':
                            break
                            
                        try:
                            chunk = json.loads(line_str)
                            if 'choices' in chunk and len(chunk['choices']) > 0:
                                delta = chunk['choices'][0].get('delta', {})
                                if 'content' in delta:
                                    yield delta['content']
                        except json.JSONDecodeError:
                            continue  # Skip malformed JSON lines
                            
        except requests.exceptions.RequestException as e:
            raise ConnectionError(f"Error during LM Studio streaming: {e}") from e
    
    def _generate_content_impl(self, prompt: str) -> str:
        """Implementation for base class compatibility."""
        return self.generate_content(prompt)

    def _generate_content_stream_impl(self, prompt: str) -> Generator[str, None, None]:
        """Implementation for base class compatibility."""
        return self.generate_content_stream(prompt)
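
# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions: LM Studio is serving at
# http://localhost:1234, "your-model-name" is a placeholder for a model id
# visible under /v1/models, and this module is run as part of the package
# (e.g. `python -m llm_clients.lmstudio`), since it uses a relative import.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = {
        "host": "http://localhost:1234",
        "model": "your-model-name",  # placeholder model id
        "temperature": 0.1,
        "max_tokens": 500,
    }
    client = LmstudioClient(config, system_prompt="You are a helpful assistant.")

    # Non-streaming call:
    print(client.generate_content("Say hello in one sentence."))

    # Streaming call:
    for token in client.generate_content_stream("Count from 1 to 5."):
        print(token, end="", flush=True)
    print()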