File size: 3,507 Bytes
ec1fb2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os, json, re
from dataclasses import dataclass
from typing import Dict, Any, Optional, Tuple

# OpenAI
try:
    from openai import OpenAI
except Exception:
    OpenAI = None

# Anthropic
try:
    import anthropic
except Exception:
    anthropic = None

class ProviderKind:
    """String identifiers for the LLM backends this module can talk to.

    The values are plain strings so they can be compared directly against
    the ``kind`` argument of ``get_provider`` and stored on ``Provider.kind``.
    """

    # Backend served through the ``openai`` client package.
    OPENAI = "openai"
    # Backend served through the ``anthropic`` client package.
    ANTHROPIC = "anthropic"

@dataclass
class Provider:
    """Common description of a judge backend plus the interface it must offer.

    Concrete backends subclass this and override :meth:`judge`; the base
    implementation only signals that no backend logic is attached.
    """

    # Backend identifier (one of the ``ProviderKind`` string constants).
    kind: str
    # Model name forwarded to the backend API.
    model: str
    # Human-readable name for the provider/model pair.
    label: str

    def judge(
        self,
        system_prompt: str,
        user_prompt: str,
    ) -> Tuple[Dict[str, Any], Dict[str, int]]:
        """Ask the backend for a verdict on ``user_prompt``.

        Args:
            system_prompt: Instructions passed as the system message.
            user_prompt: The content to be judged.

        Returns:
            A ``(verdict, usage)`` pair: the parsed JSON verdict and a
            token-usage mapping.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError

class OpenAIProvider(Provider):
    """Judge backend that calls the OpenAI Chat Completions endpoint."""

    def __init__(self, model: str):
        """Create a client for ``model``.

        The API key is read from the ``OPENAI_API_KEY`` environment variable.

        Raises:
            RuntimeError: If the ``openai`` package could not be imported.
        """
        super().__init__(ProviderKind.OPENAI, model, f"OpenAI/{model}")
        if OpenAI is None:
            raise RuntimeError("openai package not available. Add to requirements.txt")
        self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    def judge(self, system_prompt: str, user_prompt: str) -> Tuple[Dict[str, Any], Dict[str, int]]:
        """Send both prompts to the model and parse its reply as JSON.

        Args:
            system_prompt: Sent as the ``system`` message.
            user_prompt: Sent as the ``user`` message.

        Returns:
            ``(verdict, usage)`` where ``verdict`` is the reply parsed by
            ``_safe_json`` and ``usage`` has the integer keys ``"prompt"``,
            ``"completion"`` and ``"total"``.
        """
        # Use chat.completions for broader compatibility
        resp = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0,
        )
        content = resp.choices[0].message.content
        usage = resp.usage

        # Accept either naming scheme for the counters; getattr with a default
        # also covers ``usage`` being None.
        prompt_tokens = getattr(usage, "prompt_tokens", 0) or getattr(usage, "input_tokens", 0) or 0
        completion_tokens = (
            getattr(usage, "completion_tokens", 0) or getattr(usage, "output_tokens", 0) or 0
        )
        # Fall back to the sum when no total is reported, so "total" is never
        # 0 while the individual counters are known.
        total_tokens = getattr(usage, "total_tokens", 0) or (prompt_tokens + completion_tokens)

        usage_dict = {
            "prompt": prompt_tokens,
            "completion": completion_tokens,
            "total": total_tokens,
        }
        return _safe_json(content), usage_dict

class AnthropicProvider(Provider):
    """Judge backend that calls the Anthropic Messages endpoint."""

    def __init__(self, model: str, max_tokens: int = 1000):
        """Create a client for ``model``.

        The API key is read from the ``ANTHROPIC_API_KEY`` environment
        variable.

        Args:
            model: Model name forwarded to the API.
            max_tokens: Upper bound on generated tokens per ``judge`` call.
                Defaults to 1000, the value that was previously hard-coded.

        Raises:
            RuntimeError: If the ``anthropic`` package could not be imported.
        """
        super().__init__(ProviderKind.ANTHROPIC, model, f"Anthropic/{model}")
        if anthropic is None:
            raise RuntimeError("anthropic package not available. Add to requirements.txt")
        self.max_tokens = max_tokens
        self.client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))

    def judge(self, system_prompt: str, user_prompt: str) -> Tuple[Dict[str, Any], Dict[str, int]]:
        """Send both prompts to the model and parse its reply as JSON.

        Args:
            system_prompt: Passed as the ``system`` parameter.
            user_prompt: Sent as the single ``user`` message.

        Returns:
            ``(verdict, usage)`` where ``verdict`` is the reply parsed by
            ``_safe_json`` and ``usage`` has the integer keys ``"prompt"``,
            ``"completion"`` and ``"total"``.
        """
        resp = self.client.messages.create(
            model=self.model,
            max_tokens=self.max_tokens,
            temperature=0,
            system=system_prompt,
            messages=[{"role": "user", "content": user_prompt}],
        )
        # Anthropic returns a list of content blocks; only those exposing a
        # ``text`` attribute contribute to the reply.
        content = "".join(b.text for b in resp.content if hasattr(b, "text"))

        usage = getattr(resp, "usage", None)
        # ``or 0`` guards against a missing usage object and against a counter
        # that is present but None, which would break the addition below.
        prompt_tokens = getattr(usage, "input_tokens", 0) or 0
        completion_tokens = getattr(usage, "output_tokens", 0) or 0
        usage_dict = {
            "prompt": prompt_tokens,
            "completion": completion_tokens,
            "total": prompt_tokens + completion_tokens,
        }
        return _safe_json(content), usage_dict


def get_provider(kind: str, model: str) -> Provider:
    """Build the provider registered under ``kind`` for the given ``model``.

    Args:
        kind: One of the ``ProviderKind`` string constants.
        model: Model name handed to the provider's constructor.

    Returns:
        A freshly constructed provider instance.

    Raises:
        ValueError: If ``kind`` matches no known provider.
    """
    registry = (
        (ProviderKind.OPENAI, OpenAIProvider),
        (ProviderKind.ANTHROPIC, AnthropicProvider),
    )
    # Compared with ``==`` in declaration order; the first match wins.
    for known_kind, provider_cls in registry:
        if kind == known_kind:
            return provider_cls(model)
    raise ValueError(f"Unknown provider kind: {kind}")

# ----------------------
# Helpers
# ----------------------
JSON_RE = re.compile(r"\{[\s\S]*\}")

def _safe_json(s: str) -> dict:
    try:
        return json.loads(s)
    except Exception:
        m = JSON_RE.search(s or "")
        if m:
            try:
                return json.loads(m.group(0))
            except Exception:
                pass
    return {}