Spaces:
Paused
Paused
| from openai import OpenAI | |
| import anthropic | |
| from together import Together | |
| import os | |
| from atla import Atla | |
| from dotenv import load_dotenv | |
| from .prompts import ( | |
| JUDGE_SYSTEM_PROMPT | |
| ) | |
| from transformers import AutoTokenizer | |
| import requests | |
| import json | |
| import re | |
# Load API keys (e.g. OPENAI_API_KEY, ANTHROPIC_API_KEY, HF_API_KEY) from a
# local .env file into the process environment before any client is built.
load_dotenv()
# Initialize clients
# NOTE(review): each SDK client presumably reads its API key from the
# environment populated above — confirm against each SDK's docs.
anthropic_client = anthropic.Anthropic()
openai_client = OpenAI()
together_client = Together()
# HF token: used as the bearer token for the inference endpoint and for
# tokenizer download in get_selene_mini_response.
hf_api_key = os.getenv("HF_API_KEY")
atla_client = Atla()
def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
    """Send a single-turn chat request to an OpenAI model.

    Args:
        model_name: OpenAI model identifier (e.g. "gpt-4o").
        prompt: User-turn text.
        system_prompt: System message; defaults to the shared judge prompt.
        max_tokens: Cap on completion tokens.
        temperature: Sampling temperature.

    Returns:
        The assistant's reply text, or a human-readable error string if the
        request fails (this helper never raises).
    """
    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    try:
        completion = openai_client.chat.completions.create(
            model=model_name,
            messages=chat_messages,
            max_completion_tokens=max_tokens,
            temperature=temperature,
        )
        return completion.choices[0].message.content
    except Exception as err:
        return f"Error with OpenAI model {model_name}: {str(err)}"
def get_anthropic_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
    """Send a single-turn message to an Anthropic model.

    Args:
        model_name: Anthropic model identifier.
        prompt: User-turn text.
        system_prompt: System prompt; defaults to the shared judge prompt.
        max_tokens: Cap on generated tokens.
        temperature: Sampling temperature.

    Returns:
        The model's reply text, or a human-readable error string if the
        request fails (this helper never raises).
    """
    user_turn = {"role": "user", "content": [{"type": "text", "text": prompt}]}
    try:
        reply = anthropic_client.messages.create(
            model=model_name,
            max_tokens=max_tokens,
            temperature=temperature,
            system=system_prompt,
            messages=[user_turn],
        )
        return reply.content[0].text
    except Exception as err:
        return f"Error with Anthropic model {model_name}: {str(err)}"
def get_atla_response(model_name, prompt, system_prompt=None, max_tokens=500, temperature=0.01):
    """Run an evaluation with the Atla Selene judge.

    Unlike the other helpers in this file, `prompt` is a dict, not a string.
    Keys read: 'human_input', 'ai_response', 'ground_truth' (optional),
    'eval_criteria'.

    NOTE(review): model_name only appears in the error message; the model id
    is hard-coded to "atla-selene". system_prompt/max_tokens/temperature are
    accepted for signature parity but unused — confirm this is intentional.

    Returns:
        {"score": ..., "critique": ...} on success, or an error string on
        failure (this helper never raises).
    """
    try:
        eval_response = atla_client.evaluation.create(
            model_id="atla-selene",
            model_input=prompt.get('human_input', ''),
            model_output=prompt.get('ai_response', ''),
            # Falsy ground truth (missing or empty) is normalized to None.
            expected_model_output=prompt.get('ground_truth') or None,
            evaluation_criteria=prompt.get('eval_criteria', ''),
        )
        evaluation = eval_response.result.evaluation
        return {"score": evaluation.score, "critique": evaluation.critique}
    except Exception as err:
        return f"Error with Atla model {model_name}: {str(err)}"
# Process-level tokenizer cache: AutoTokenizer.from_pretrained downloads and
# parses tokenizer files, which is far too slow to repeat on every request.
_selene_tokenizer_cache = {}


def _get_selene_tokenizer(model_id):
    """Return the tokenizer for model_id, loading it at most once per process."""
    if model_id not in _selene_tokenizer_cache:
        _selene_tokenizer_cache[model_id] = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
    return _selene_tokenizer_cache[model_id]


def get_selene_mini_response(model_name, prompt, system_prompt=None, max_tokens=500, temperature=0.01):
    """Get response from HF endpoint for Atla model.

    Formats `prompt` (and optional `system_prompt`) with the Selene Mini chat
    template and POSTs it to the dedicated HF inference endpoint.

    Args:
        model_name: Only used in the error message (the endpoint is fixed).
        prompt: User-turn text.
        system_prompt: Optional system message prepended to the chat.
        max_tokens: max_new_tokens for generation.
        temperature: Sampling temperature (endpoint rejects exactly 0).

    Returns:
        The generated text, or a human-readable error string on any failure
        (this helper never raises).
    """
    try:
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {hf_api_key}",
            "Content-Type": "application/json"
        }
        # Build the chat and render it with the model's own template so the
        # endpoint receives exactly the prompt format Selene was trained on.
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        model_id = "AtlaAI/Selene-1-Mini-Llama-3.1-8B"
        tokenizer = _get_selene_tokenizer(model_id)
        formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        payload = {
            "inputs": formatted_prompt,
            "parameters": {
                "max_new_tokens": max_tokens,
                "return_full_text": False,
                "temperature": temperature,
                "seed": 42,
                "add_generation_prompt": True
            }
        }
        response = requests.post(
            "https://bkp9p28gri93egqh.us-east-1.aws.endpoints.huggingface.cloud",
            headers=headers,
            json=payload,
            timeout=120,  # don't hang forever on a stalled/scaled-to-zero endpoint
        )
        # Surface 4xx/5xx as a readable error instead of a KeyError below.
        response.raise_for_status()
        return response.json()[0]["generated_text"]
    except Exception as e:
        return f"Error with Atla model {model_name}: {str(e)}"
def parse_selene_mini_response(response_text):
    """Extract the score and critique from a Selene Mini judgement.

    Expects markdown shaped like '**Reasoning:** ... **Result:** <digits>'
    (colon optional, case-insensitive, reasoning may span lines).

    Returns:
        {"score": <digits-as-str>, "critique": <reasoning text>} on success;
        on any failure, {"score": "Error", "critique": <explanation plus the
        raw response>}. Never raises.
    """
    try:
        text = response_text.strip()
        # Lazy match up to the Result marker (or end of text) for the critique.
        critique_match = re.search(
            r'\*\*Reasoning:?\*\*\s*(.*?)(?=\*\*Result|$)',
            text,
            re.DOTALL | re.IGNORECASE,
        )
        score_match = re.search(r'\*\*Result:?\*\*\s*(\d+)', text, re.IGNORECASE)
        if not (critique_match and score_match):
            # Unparseable: hand the caller the raw text so nothing is lost.
            return {
                "score": "Error",
                "critique": f"Failed to parse response. Raw response:\n{text}"
            }
        return {
            "score": score_match.group(1),
            "critique": critique_match.group(1).strip(),
        }
    except Exception as e:
        return {
            "score": "Error",
            "critique": f"Error parsing response: {str(e)}\nRaw response:\n{response_text}"
        }