# llm_classifier/app/utils/classification.py
# Author: argmin — commit 510a9b0 ("add files")
def apply_classification(client, model_params, ClassificationOutput, system_prompt, user_prompt, verbose=False, st=None):
    """Classify ``user_prompt`` with a chat-completion model and validate the label.

    Args:
        client: Chat client exposing ``client.chat.completions.create(...)``
            (OpenAI-style interface — presumably the OpenAI SDK; confirm).
        model_params: Mapping with ``"model"``, ``"max_tokens"`` and
            ``"temperature"`` keys, forwarded unchanged to the completion call.
        ClassificationOutput: Validation model class. The stripped model reply
            is validated as ``{"label": <reply>}`` and the resulting ``label``
            attribute is returned.
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        verbose: When true and ``st`` is provided, the raw reply and any
            validation error are reported through ``st.info`` / ``st.error``.
        st: Optional object with ``info`` and ``error`` methods (presumably the
            ``streamlit`` module — verify against callers).

    Returns:
        The validated label, or the string ``"INVALID"`` when the reply fails
        validation (including an empty or missing reply the validator rejects).
    """
    response = client.chat.completions.create(
        model=model_params["model"],
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        max_tokens=model_params["max_tokens"],
        temperature=model_params["temperature"],
    )
    # ``message.content`` may be ``None`` (e.g. refusals / tool calls). Treat it
    # as an empty reply so it is judged by the validator instead of crashing
    # on ``.strip()`` outside the error handling below.
    raw_prediction = (response.choices[0].message.content or "").strip()
    # Log raw prediction for debugging
    if verbose and st:
        st.info(f"Raw Prediction: {raw_prediction}")
    # Validate and process the prediction. Any failure (including a model class
    # lacking a validation method) is deliberately mapped to "INVALID".
    try:
        # Pydantic v2 deprecates ``parse_obj`` in favour of ``model_validate``;
        # prefer the new API when present and fall back for Pydantic v1.
        validate = getattr(ClassificationOutput, "model_validate", None)
        if validate is None:
            validate = ClassificationOutput.parse_obj
        validated_prediction = validate({"label": raw_prediction}).label
    except Exception as e:
        if verbose and st:
            st.error(f"Invalid prediction: {raw_prediction}. Error: {e}")
        return "INVALID"
    return validated_prediction