| from unsloth import FastLanguageModel |
| import torch |
|
|
| |
# Load the fine-tuned context-relevance classifier and its tokenizer through
# Unsloth. Both names are module-level globals because classify_answer()
# reads them directly.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="axondendriteplus/context-relevance-classifier",
    load_in_4bit=True,  # request 4-bit quantized weights
    dtype=None,  # no explicit dtype; the choice is left to Unsloth
    max_seq_length=2048,  # maximum sequence length requested at load time
)

# Put the model into Unsloth's inference mode before any generate() call.
FastLanguageModel.for_inference(model)
|
|
def classify_answer(question, answer, context, *, verbose=True):
    """Ask the classifier whether *answer* was derived from *context*.

    Builds a chat-formatted prompt (system instruction, then the question,
    answer and context as the user turn, then an open assistant header) and
    greedily decodes at most a few tokens from the module-level ``model``.

    Args:
        question: The question that was asked.
        answer: The candidate answer to judge.
        context: The context passage the answer may have been generated from.
        verbose: When True (the default), print the full decoded transcript
            and the extracted prediction.

    Returns:
        True if the model's reply contains "YES" (case-insensitive),
        otherwise False.
    """
    prompt = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
        'You are a context relevance classifier. Given a question, answer, '
        'and context, determine if the answer was generated from the given '
        'context. Respond with either "YES" if the answer is derived from the '
        'context, or "NO" if it is not.\n'
        "\n"
        "<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
        f"Question: {question}\n"
        "\n"
        f"Answer: {answer}\n"
        "\n"
        f"Context: {context}\n"
        "\n"
        "Was this answer generated from the given context? "
        "Respond with YES or NO only.\n"
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
    )

    # NOTE(review): the prompt already contains <|begin_of_text|> and the
    # tokenizer is called with its default add_special_tokens, so a Llama-3
    # tokenizer may prepend a second BOS token. Left unchanged in case the
    # model was fine-tuned with the same duplication -- confirm against the
    # training script before passing add_special_tokens=False.
    # Follow the model's own device instead of assuming "cuda".
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,
            use_cache=True,
            do_sample=False,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; decode only the continuation
    # rather than splitting the whole transcript on the word "assistant",
    # which breaks if the generated text itself contains that word.
    prediction = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    if verbose:
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Response: {response}")
        print(f"Prediction: {prediction}")

    return "YES" in prediction.upper()
|
|
| |
if __name__ == "__main__":
    # Example run: classify one hand-written question/answer/context triple.
    # Guarded so that importing this module does not trigger a classification
    # and its console output.
    question = "What is the legal definition of contract?"
    answer = "A contract is a legally binding agreement between two parties."
    context = (
        "Contract law defines a contract as an agreement between two or more "
        "parties that creates legally enforceable obligations."
    )

    result = classify_answer(question, answer, context)
    print(f"result : {result}")