File size: 965 Bytes
510a9b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def apply_classification(client, model_params, ClassificationOutput, system_prompt, user_prompt, verbose=False, st=None):
    """Ask a chat model for a label and validate it against a schema.

    Sends ``system_prompt`` and ``user_prompt`` through
    ``client.chat.completions.create`` and validates the stripped reply as
    the ``label`` field of ``ClassificationOutput``.

    Args:
        client: Object exposing ``chat.completions.create(...)``.
        model_params: Mapping with the keys ``"model"``, ``"max_tokens"``
            and ``"temperature"``, forwarded to the completion call.
        ClassificationOutput: Validation class exposing ``model_validate``
            or ``parse_obj``; the validated object must have a ``label``
            attribute.
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        verbose: When true (and ``st`` is given), report the raw prediction
            and any validation failure through ``st``.
        st: Optional reporting object providing ``info`` and ``error``
            (presumably the Streamlit module; confirm against the caller).

    Returns:
        The validated label, or the string ``"INVALID"`` when the model
        returned no text or the text failed validation.
    """
    response = client.chat.completions.create(
        model=model_params["model"],
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        max_tokens=model_params["max_tokens"],
        temperature=model_params["temperature"],
    )
    content = response.choices[0].message.content

    # The message content may be None (no text returned); treat that as an
    # invalid prediction instead of crashing on ``None.strip()``.
    if content is None:
        if verbose and st:
            st.error(f"Invalid prediction: {content}. Error: model returned no content")
        return "INVALID"

    raw_prediction = content.strip()

    # Log raw prediction for debugging
    if verbose and st:
        st.info(f"Raw Prediction: {raw_prediction}")

    # Prefer the Pydantic v2 entry point when available; fall back to the
    # v1-style ``parse_obj`` so older schema classes keep working.
    validate = getattr(ClassificationOutput, "model_validate", None)
    if validate is None:
        validate = ClassificationOutput.parse_obj

    # Validate and process the prediction. The broad catch is deliberate:
    # any validation failure maps to the "INVALID" sentinel.
    try:
        validated_prediction = validate({"label": raw_prediction}).label
    except Exception as e:
        if verbose and st:
            st.error(f"Invalid prediction: {raw_prediction}. Error: {e}")
        return "INVALID"

    return validated_prediction