axondendriteplus commited on
Commit
fd1679b
·
verified ·
1 Parent(s): 61f0110

Upload inference_SFT.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference_SFT.py +56 -0
inference_SFT.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unsloth import FastLanguageModel
2
+ import torch
3
+
4
+ # Load the fine-tuned model
5
+ model, tokenizer = FastLanguageModel.from_pretrained(
6
+ model_name="axondendriteplus/context-relevance-classifier",
7
+ max_seq_length=2048,
8
+ dtype=None,
9
+ load_in_4bit=True,
10
+ )
11
+
12
+ # Enable inference mode
13
+ FastLanguageModel.for_inference(model)
14
+
15
+ def classify_answer(question, answer, context):
16
+ """Classify if answer is generated from context"""
17
+ prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
18
+ You are a context relevance classifier. Given a question, answer, and context, determine if the answer was generated from the given context. Respond with either "YES" if the answer is derived from the context, or "NO" if it is not.
19
+
20
+ <|eot_id|><|start_header_id|>user<|end_header_id|>
21
+ Question: {question}
22
+
23
+ Answer: {answer}
24
+
25
+ Context: {context}
26
+
27
+ Was this answer generated from the given context? Respond with YES or NO only.
28
+ <|eot_id|><|start_header_id|>assistant<|end_header_id|>
29
+ """
30
+
31
+ inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
32
+
33
+ with torch.no_grad():
34
+ outputs = model.generate(
35
+ **inputs,
36
+ max_new_tokens=5,
37
+ use_cache=True,
38
+ do_sample=False,
39
+ repetition_penalty=1.1,
40
+ eos_token_id=tokenizer.eos_token_id,
41
+ )
42
+
43
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
44
+ print(f"Response: {response}")
45
+ prediction = response.split("assistant")[-1].strip()
46
+ print(f"Prediction: {prediction}")
47
+
48
+ return "YES" in prediction.upper()
49
+
50
+ # Test the model
51
+ question = "What is the legal definition of contract?"
52
+ answer = "A contract is a legally binding agreement between two parties."
53
+ context = "Contract law defines a contract as an agreement between two or more parties that creates legally enforceable obligations."
54
+
55
+ result = classify_answer(question, answer, context)
56
+ print(f"result : {result}")