# Hugging Face Space page header (status: Sleeping)
| import os | |
| import torch | |
| from transformers import BertTokenizer | |
| from evo_model import EvoTransformerForClassification | |
| from openai import OpenAI | |
# --- EvoTransformer classifier and its tokenizer ----------------------------
# Both artifacts are loaded from the same local checkpoint directory.
_CHECKPOINT_DIR = "trained_model"

# Use the GPU when PyTorch reports one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = EvoTransformerForClassification.from_pretrained(_CHECKPOINT_DIR)
model.to(device)
model.eval()  # evaluation mode: this module only runs inference

tokenizer = BertTokenizer.from_pretrained(_CHECKPOINT_DIR)

# --- OpenAI client for the GPT-3.5 comparison --------------------------------
# The API key is taken from the OPENAI_API_KEY environment variable.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
def query_gpt35(prompt, *, model_name="gpt-3.5-turbo", max_tokens=60, temperature=0.3):
    """Send ``prompt`` as a single user message to an OpenAI chat model.

    Failures are never raised to the caller. Any error is returned as a string
    prefixed with ``[GPT-3.5 Error]`` so it can be shown in place of an answer.

    Args:
        prompt: Text of the single user turn.
        model_name: Chat-completions model to query.
        max_tokens: Upper bound on the length of the reply.
        temperature: Sampling temperature passed to the API.

    Returns:
        The reply text with surrounding whitespace stripped, or an
        ``[GPT-3.5 Error] ...`` message.
    """
    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        content = response.choices[0].message.content
    except Exception as e:  # UI boundary: surface API/network failures as text
        return f"[GPT-3.5 Error] {e}"
    # The SDK types message.content as optional. Report a missing reply
    # explicitly instead of failing on None.strip().
    if content is None:
        return "[GPT-3.5 Error] Empty response from model."
    return content.strip()
def get_logits(output):
    """Extract the logits from a classifier forward-pass result.

    Accepted forms:
      * a tuple: ``(loss, logits, ...)`` yields element 1, while a one-element
        ``(logits,)`` yields element 0;
      * an object exposing a ``.logits`` attribute;
      * anything else, which is returned unchanged as the raw logits.
    """
    if isinstance(output, tuple):
        # NOTE(review): multi-element tuples are assumed to be ordered
        # (loss, logits, ...); confirm against the model's forward().
        return output[0] if len(output) == 1 else output[1]
    if hasattr(output, "logits"):
        return output.logits
    return output
def _evo_score(goal, option):
    """Return EvoTransformer's class-1 log-probability for the pair (goal, option).

    The pair is encoded as the text ``"<goal> [SEP] <option>"`` and classified.
    It is then scored by the log-softmax of class 1. Unlike the raw class-1
    logit, this normalized score makes two independently classified options
    comparable on the same scale.

    Raises:
        ValueError: if the model returns fewer than two class logits.
    """
    encoding = tokenizer(
        f"{goal} [SEP] {option}",
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128,
    )
    # EvoTransformer crashes when given token_type_ids, so drop them.
    encoding.pop("token_type_ids", None)
    inputs = {name: tensor.to(device) for name, tensor in encoding.items()}

    with torch.no_grad():
        logits = get_logits(model(**inputs))

    if logits.shape[-1] < 2:
        raise ValueError("Model did not return 2-class logits.")
    # Class 1 is the class this app ranks options by. log_softmax is monotone
    # in that class's probability, so comparing it compares probabilities.
    return torch.log_softmax(logits[0], dim=-1)[1].item()


def generate_response(goal, option1, option2):
    """Compare two candidate options for ``goal`` with EvoTransformer and GPT-3.5.

    Returns:
        A dict with two keys:
          ``"evo_suggestion"``: the option EvoTransformer scores higher (ties go
          to ``option2``), or an ``"[Evo Error] ..."`` message if scoring fails.
          ``"gpt_suggestion"``: the text returned by ``query_gpt35``.
    """
    try:
        score1 = _evo_score(goal, option1)
        score2 = _evo_score(goal, option2)
        evo_result = option1 if score1 > score2 else option2
    except Exception as e:  # UI boundary: report the failure instead of raising
        evo_result = f"[Evo Error] {e}"

    # GPT-3.5 comparison of the same two options.
    prompt = f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\nWhich is better and why?"
    gpt_result = query_gpt35(prompt)

    return {
        "evo_suggestion": evo_result,
        "gpt_suggestion": gpt_result,
    }