| | import os |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| |
|
| | |
def model_fn(model_dir):
    """
    Load the model and tokenizer from the specified directory.

    The tokenizer is loaded with ``AutoTokenizer.from_pretrained(model_dir)``.
    The model architecture is built from ``<model_dir>/config.json`` and its
    weights are restored from the state dict stored in
    ``<model_dir>/pytorch_model.pth``.

    Args:
        model_dir: Directory containing the tokenizer files, ``config.json``
            and ``pytorch_model.pth``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is on the CPU and in
        evaluation mode.

    Raises:
        FileNotFoundError: If ``pytorch_model.pth`` or ``config.json`` is
            missing from ``model_dir``.
    """
    # AutoConfig was used by this function but never imported.
    from transformers import AutoConfig

    print("Loading model and tokenizer...")

    model_path = os.path.join(model_dir, "pytorch_model.pth")
    config_path = os.path.join(model_dir, "config.json")

    # Cheap existence checks first so a bad model_dir fails fast with a clear message.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    config = AutoConfig.from_pretrained(config_path)
    # Auto* classes cannot be instantiated with their constructor; from_config
    # builds the concrete architecture (with fresh weights) for this config.
    model = AutoModelForSeq2SeqLM.from_config(config)
    # map_location="cpu" lets weights that were saved from a GPU load on a
    # CPU-only host. NOTE: torch.load unpickles the file, so only load
    # trusted model artifacts.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # nn.Module starts in training mode; switch off dropout for inference.
    model.eval()
    print("Model and tokenizer loaded successfully.")

    return model, tokenizer
| |
|
def input_fn(serialized_input_data, content_type="application/json"):
    """
    Deserialize the request body from JSON and extract the PL/SQL source.

    Args:
        serialized_input_data: Raw request body (``str`` or ``bytes``) that
            should hold a JSON object with a ``"plsql_code"`` key.
        content_type: Request MIME type. Media-type parameters such as
            ``; charset=utf-8`` and letter case are ignored when matching
            ``application/json``.

    Returns:
        The value stored under ``"plsql_code"``.

    Raises:
        ValueError: If the content type is not JSON, the body is not valid
            JSON (``json.JSONDecodeError`` is a ``ValueError``), the body is
            not a JSON object, or ``"plsql_code"`` is missing.
    """
    import json

    print("Processing input data...")
    # Compare only the bare media type so e.g. "application/json; charset=utf-8"
    # is accepted; a None content type falls through to the error below.
    media_type = (content_type or "").split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        raise ValueError(f"Unsupported content type: {content_type}")

    input_data = json.loads(serialized_input_data)
    # Without the isinstance check, a JSON array or string would pass the
    # membership test (element/substring match) and then fail on indexing
    # with a TypeError instead of this ValueError.
    if not isinstance(input_data, dict) or "plsql_code" not in input_data:
        raise ValueError("Missing 'plsql_code' in the input JSON.")
    print("Input data processed successfully.")
    return input_data["plsql_code"]
| |
|
def predict_fn(input_data, model_and_tokenizer):
    """
    Translate PL/SQL code to Hibernate/JPA-based Java code using the trained model.

    The PL/SQL source is embedded in a fixed instruction prompt, tokenized
    (truncated to 1024 tokens), and decoded with 4-beam search.

    Args:
        input_data: PL/SQL source text, as returned by ``input_fn``.
        model_and_tokenizer: ``(model, tokenizer)`` tuple, as returned by
            ``model_fn``.

    Returns:
        The decoded text of the top beam with special tokens removed.
    """
    print("Starting prediction...")
    model, tokenizer = model_and_tokenizer

    prompt = f"""
Translate this PL/SQL function to a Hibernate/JPA-based Java implementation.

Requirements:
1. Use @Entity, @Table, and @Column annotations to map the database table structure.
2. Define Java fields corresponding to the database columns used in the PL/SQL logic.
3. Replicate the PL/SQL logic as a @Query in the Repository layer or as Java logic in the Service layer.
4. Use Repository and Service layers, ensuring transactional consistency with @Transactional annotations.
5. Avoid direct bitwise operations in procedural code; ensure they are part of database entities or queries.

Input PL/SQL:
{input_data}
"""

    print("Tokenizing input...")
    # NOTE: prompt + code beyond 1024 tokens is silently truncated.
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    # Place tensors on the model's device so a GPU-hosted model also works.
    input_ids = inputs["input_ids"].to(model.device)
    # Pass the attention mask explicitly rather than letting generate() infer it.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)

    print("Generating output...")
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=1024,
        num_beams=4,
        early_stopping=True,
    )
    translated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("Prediction completed.")
    return translated_code
| |
|
def output_fn(prediction, accept="application/json"):
    """
    Serialize the prediction result to JSON format.

    Args:
        prediction: Translated code returned by ``predict_fn``.
        accept: Requested response MIME type. Media-type parameters such as
            ``; charset=utf-8`` and letter case are ignored when matching
            ``application/json``.

    Returns:
        A ``(body, content_type)`` tuple. ``body`` is the JSON string
        ``{"translated_code": prediction}`` and ``content_type`` is
        ``"application/json"``.

    Raises:
        ValueError: If ``accept`` is not ``application/json``.
    """
    import json

    print("Serializing output...")
    # Match on the bare media type so "application/json; charset=utf-8" is accepted.
    media_type = (accept or "").split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        raise ValueError(f"Unsupported accept type: {accept}")
    return json.dumps({"translated_code": prediction}), "application/json"
| |
|
| |
|