nguyenthanhasia committed on
Commit
57a3905
·
verified ·
1 Parent(s): 6736931

Upload example_usage.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. example_usage.py +112 -0
example_usage.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example usage of Paraformer model for legal document retrieval.
3
+
4
+ This is a simplified implementation. For full functionality and customization,
5
+ visit: https://github.com/nguyenthanhasia/paraformer
6
+
7
+ License: Research purposes - free to use. Commercial purposes - at your own risk.
8
+ """
9
+
10
+ from transformers import AutoModel
11
+ import torch
12
+
13
def _print_section(title):
    """Print a numbered section title followed by a dashed rule."""
    print(title)
    print("-" * 30)


def _run_single_pair_example(model):
    """Score one query against one article using the model's helper methods."""
    _print_section("\n1. Single Query-Article Example:")

    query = "What are the legal requirements for contract formation?"
    article = [
        "A contract is a legally binding agreement between two or more parties.",
        "For a contract to be valid, it must have offer, acceptance, and consideration.",
        "The parties must have legal capacity to enter into the contract.",
    ]

    print(f"Query: {query}")
    print(f"Article: {len(article)} sentences")

    # Get relevance score
    relevance_score = model.get_relevance_score(query, article)
    print(f"Relevance Score: {relevance_score:.4f}")

    # Get binary prediction
    prediction = model.predict_relevance(query, article)
    print(f"Binary Output: {prediction} (0=lower similarity, 1=higher similarity)")


def _run_batch_example(model):
    """Run several query/article pairs through one forward pass and print the results."""
    _print_section("\n2. Batch Processing Example:")

    queries = [
        "What constitutes a valid contract?",
        "How can employment be terminated?",
        "What are the requirements for copyright protection?",
    ]

    articles = [
        ["A contract requires offer, acceptance, and consideration.", "All parties must have legal capacity."],
        ["Employment can be terminated by mutual agreement.", "Notice period must be respected."],
        ["Copyright protects original works of authorship.", "The work must be fixed in a tangible medium."],
    ]

    # Call the module instance rather than .forward() so that nn.Module's
    # __call__ machinery (registered hooks) is not bypassed.
    outputs = model(
        query_texts=queries,
        article_texts=articles,
        return_dict=True,
    )

    # Get probabilities and predictions
    probabilities = torch.softmax(outputs.logits, dim=-1)
    predictions = torch.argmax(outputs.logits, dim=-1)

    for i, query in enumerate(queries):
        # Column 1 of the softmax output is reported as the similarity score.
        score = probabilities[i, 1].item()
        pred = predictions[i].item()
        print(f"\nQuery {i+1}: {query}")
        print(f" Similarity Score: {score:.4f}")
        print(f" Binary Output: {pred}")


def _run_attention_example(model):
    """Print the per-sentence attention weights for one query, when the model returns them."""
    _print_section("\n3. Attention Weights Example:")

    query = "What is required for a valid contract?"
    article = [
        "A contract is an agreement between parties.",
        "It must have offer and acceptance.",
        "Consideration is also required.",
        "The weather is nice today.",  # Irrelevant sentence
    ]

    outputs = model(
        query_texts=[query],
        article_texts=[article],
        return_dict=True,
    )

    # Nothing to show when the model output carries no attention weights.
    if outputs.attentions is None:
        return

    attention_weights = outputs.attentions[0, 0]  # First batch, first query
    print(f"Query: {query}")
    print("Attention weights per sentence:")
    for i, (sentence, weight) in enumerate(zip(article, attention_weights)):
        print(f" Sentence {i+1}: {weight:.4f} - {sentence}")


def _print_notes():
    """Print the closing usage and licensing notes."""
    print("\n" + "=" * 50)
    print("Important Notes:")
    print("- Scores represent similarity in learned feature space, not absolute relevance")
    print("- This is a simplified implementation for easy integration")
    print("- For full functionality: https://github.com/nguyenthanhasia/paraformer")
    print("- Research use: free | Commercial use: at your own risk")


def main():
    """Load the Paraformer model from the Hugging Face Hub and run three demos.

    The demos are: scoring a single query/article pair, batch scoring through
    a forward pass, and printing per-sentence attention weights. All results
    are printed to stdout; nothing is returned.
    """
    print("Paraformer Model - Example Usage")
    print("=" * 50)

    # Load the model
    print("Loading model from Hugging Face Hub...")
    model = AutoModel.from_pretrained('nguyenthanhasia/paraformer', trust_remote_code=True)
    print("✓ Model loaded successfully")

    # The demos only run inference, so disable autograd bookkeeping to avoid
    # building computation graphs that are never used.
    with torch.no_grad():
        _run_single_pair_example(model)
        _run_batch_example(model)
        _run_attention_example(model)

    _print_notes()
109
+
110
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not on import.
    main()
112
+