# OpenGPT-5.5 / configuration_openai.py
# (Hugging Face page residue preserved as comments: uploaded by "goodgoals",
# commit 8f9b0c3 verified, "Update configuration_openai.py".)
import os

import openai
import safetensors.torch
import torch
from transformers import AutoModel, AutoTokenizer
# ==== CONFIG ====
# NOTE(review): never commit a real API key to source control — read it from
# the environment; the placeholder is kept only as a backward-compatible default.
openai.api_key = os.environ.get("OPENAI_API_KEY", "PUT-YOUR-KEY-HERE!")
MODEL_NAME = "gpt-5.2"
SAFETENSORS_PATH = "model-extended.safetensors"  # Your correction vectors
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # CPU embedding model

# ==== LOAD CORRECTION VECTORS ====
# Vectors are stored as a single tensor under the key 'correction_vectors',
# shape [num_vectors, embedding_dim].
correction_data = safetensors.torch.load_file(SAFETENSORS_PATH)
correction_vectors = correction_data["correction_vectors"]  # shape: [N, D]
# Mean over all correction vectors -> a single [D] offset applied to embeddings.
mean_correction = correction_vectors.mean(dim=0)

# ==== LOAD EMBEDDING MODEL (CPU) ====
tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_NAME)
embed_model = AutoModel.from_pretrained(EMBED_MODEL_NAME)
embed_model.eval()  # inference mode: disables dropout etc.
embed_model = embed_model.to("cpu")
def get_embedding(text):
    """Embed *text* with the CPU model and return its [CLS] token vector.

    Returns a 1-D tensor of shape [embedding_dim].
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        hidden = embed_model(**encoded).last_hidden_state
    # First token ([CLS]) of the single batch item serves as the sentence embedding.
    return hidden[0, 0, :]
# ==== FUNCTION TO APPLY CORRECTION ====
def apply_correction(output_text):
    """Embed *output_text* and shift it by the mean correction vector.

    Simple additive correction in embedding space; returns a [D] tensor.
    """
    return get_embedding(output_text) + mean_correction
# ==== INFERENCE WITH GPT-5.2 ====
def ask_gpt(prompt):
    """Query the chat model and return ``(text, corrected_embedding)``.

    NOTE(review): ``openai.ChatCompletion`` is the pre-1.0 SDK surface —
    confirm the pinned ``openai`` version still provides it.
    """
    completion = openai.ChatCompletion.create(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7
    )
    answer = completion.choices[0].message.content
    # Apply the safetensor correction to the answer's embedding; callers may
    # use it in downstream similarity / re-ranking tasks.
    corrected = apply_correction(answer)
    return answer, corrected
# ==== TEST ====
if __name__ == "__main__":
    demo_prompt = "Who was Thomas Jefferson?"
    answer, embedding = ask_gpt(demo_prompt)
    print("GPT-5.2 Output:", answer)
    print("Corrected Embedding (first 10 values):", embedding[:10])