jnkziaa commited on
Commit
7b1499d
·
verified ·
1 Parent(s): 28af0ea

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +97 -0
inference.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os

import torch
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
4
+
# SageMaker inference handler: SageMaker extracts model.tar.gz into
# /opt/ml/model and passes that path as ``model_dir``.
def model_fn(model_dir):
    """
    Load the model and tokenizer from the specified directory.

    Args:
        model_dir: Directory containing the tokenizer files, ``config.json``
            and the state-dict checkpoint ``pytorch_model.pth``.

    Returns:
        Tuple of ``(model, tokenizer)``.

    Raises:
        FileNotFoundError: If the checkpoint or config file is missing.
    """
    print("Loading model and tokenizer...")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # Explicitly load the .pth file (the model was saved as a raw state
    # dict rather than with save_pretrained()).
    model_path = os.path.join(model_dir, "pytorch_model.pth")
    config_path = os.path.join(model_dir, "config.json")

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    # Build the architecture from the config, then load the weights.
    # Auto* classes cannot be instantiated directly; use from_config().
    config = AutoConfig.from_pretrained(config_path)
    model = AutoModelForSeq2SeqLM.from_config(config)
    # map_location="cpu" so a GPU-saved checkpoint loads on CPU-only hosts.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()  # inference mode: disables dropout etc.
    print("Model and tokenizer loaded successfully.")

    return model, tokenizer
def input_fn(serialized_input_data, content_type="application/json"):
    """
    Deserialize a JSON request body and extract the PL/SQL source.

    Args:
        serialized_input_data: Raw request payload bytes/string.
        content_type: MIME type of the payload; only JSON is accepted.

    Returns:
        The PL/SQL source string stored under the ``plsql_code`` key.

    Raises:
        ValueError: On an unsupported content type or a missing key.
    """
    print("Processing input data...")
    # Guard clause: reject anything that is not JSON up front.
    if content_type != "application/json":
        raise ValueError(f"Unsupported content type: {content_type}")
    import json
    payload = json.loads(serialized_input_data)
    if "plsql_code" not in payload:
        raise ValueError("Missing 'plsql_code' in the input JSON.")
    print("Input data processed successfully.")
    return payload["plsql_code"]
def predict_fn(input_data, model_and_tokenizer):
    """
    Translate PL/SQL code to Hibernate/JPA-based Java code using the trained model.

    Args:
        input_data: PL/SQL source string (produced by ``input_fn``).
        model_and_tokenizer: ``(model, tokenizer)`` tuple from ``model_fn``.

    Returns:
        The generated Java translation as a single decoded string.
    """
    print("Starting prediction...")
    model, tokenizer = model_and_tokenizer

    # Construct the tailored prompt
    prompt = f"""
Translate this PL/SQL function to a Hibernate/JPA-based Java implementation.

Requirements:
1. Use @Entity, @Table, and @Column annotations to map the database table structure.
2. Define Java fields corresponding to the database columns used in the PL/SQL logic.
3. Replicate the PL/SQL logic as a @Query in the Repository layer or as Java logic in the Service layer.
4. Use Repository and Service layers, ensuring transactional consistency with @Transactional annotations.
5. Avoid direct bitwise operations in procedural code; ensure they are part of database entities or queries.

Input PL/SQL:
{input_data}
"""
    # Tokenize and generate the translated code
    print("Tokenizing input...")
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    print("Generating output...")
    # Inference only: no_grad avoids building the autograd graph, which
    # would otherwise waste memory/time during generation.
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            max_length=1024,
            num_beams=4,
            early_stopping=True
        )
    translated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("Prediction completed.")
    return translated_code
def output_fn(prediction, accept="application/json"):
    """
    Serialize the prediction result to JSON format.

    Args:
        prediction: Generated Java code string from ``predict_fn``.
        accept: Requested response MIME type; only JSON is supported.

    Returns:
        Tuple of ``(serialized_body, content_type)``.

    Raises:
        ValueError: If ``accept`` is not ``application/json``.
    """
    print("Serializing output...")
    # Guard clause: only JSON responses are supported.
    if accept != "application/json":
        raise ValueError(f"Unsupported accept type: {accept}")
    import json
    body = json.dumps({"translated_code": prediction})
    return body, "application/json"