File size: 9,619 Bytes
6965b32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import re
import traceback
from langchain import LLMChain, PromptTemplate
from langchain.llms import VertexAI
from libs.logger import logger
import streamlit as st
from google.oauth2 import service_account
from langchain.prompts import ChatPromptTemplate
import libs.general_utils

class VertexAILangChain:
    """Drive Vertex AI code models through LangChain for a Streamlit UI.

    ``load_model`` builds the code-generation LLM from a service-account key
    file, ``generate_code`` and ``generate_code_completion`` run prompts
    through an ``LLMChain``, and the ``set_*`` methods change one setting and
    rebuild the model. Options are read from ``st.session_state`` and problems
    are reported through ``st.toast`` / ``st.error`` and the project logger.
    """

    # Model used by generate_code_completion and the cap applied to its
    # output tokens.
    # NOTE(review): the Vertex AI docs appear to list 64 as the code-gecko
    # output limit - confirm before changing the value or the messages.
    _COMPLETION_MODEL = "code-gecko"
    _COMPLETION_MAX_TOKENS = 65

    # Markdown fence that wraps code in model responses, and the shape of the
    # optional language tag that follows the opening fence ("python", "c++").
    _CODE_FENCE = "```"
    _LANGUAGE_TAG = re.compile(r"[\w+#.-]*")

    # (session-state key, label used in the log line, guideline sent to the model)
    _GUIDELINE_RULES = (
        ("modular_code", "Modular code", "- Ensure the method is modular in its approach."),
        ("exception_handling", "Exception handling", "- Integrate robust exception handling."),
        ("error_handling", "Error handling", "- Add error handling to each module."),
        ("efficient_code", "Efficient code", "- Optimize the code to ensure it runs efficiently."),
        ("robust_code", "Robust code", "- Ensure the code is robust against potential issues."),
        ("naming_conventions", "Naming conventions", "- Follow standard naming conventions."),
    )

    def __init__(self, project="", location="us-central1", model_name="code-bison", max_tokens=256, temperature:float=0.3, credentials_file_path=None):
        """Store the connection and generation settings; no model is loaded yet.

        Args:
            project: Project identifier (only logged by ``load_model``).
            location: Vertex AI location passed to the LLM.
            model_name: Name of the code-generation model.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            credentials_file_path: Path to a service-account key file.
        """
        self.project = project
        self.location = location
        self.model_name = model_name
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.credentials_file_path = credentials_file_path
        self.vertexai_llm = None  # populated by load_model()
        self.utils = libs.general_utils.GeneralUtils()

    def load_model(self, model_name=None, max_tokens=None, temperature=None):
        """Create the Vertex AI LLM used by ``generate_code``.

        Args:
            model_name: Model to load; a falsy value selects the stored name.
            max_tokens: Output-token limit; a falsy value selects the stored limit.
            temperature: Sampling temperature; only ``None`` selects the stored
                value, because ``0`` is a legitimate temperature.

        Returns:
            True when the model was created, False when anything failed (the
            error and traceback are logged).
        """
        try:
            logger.info(f"Loading model... with project: {self.project} and location: {self.location}")
            # Build explicit credentials from the service-account key file.
            credentials = service_account.Credentials.from_service_account_file(self.credentials_file_path)

            chosen_model = model_name or self.model_name
            chosen_max_tokens = max_tokens or self.max_tokens
            chosen_temperature = self.temperature if temperature is None else temperature

            logger.info(f"Trying to set Vertex model with parameters: {chosen_model}, {chosen_max_tokens}, {chosen_temperature}, {self.location}")
            self.vertexai_llm = VertexAI(
                model_name=chosen_model,
                max_output_tokens=chosen_max_tokens,
                temperature=chosen_temperature,
                verbose=True,
                location=self.location,
                credentials=credentials,
            )
            # Remember what was actually loaded so that a later set_* reload
            # keeps these values instead of reverting to stale ones.
            self.model_name = chosen_model
            self.max_tokens = chosen_max_tokens
            self.temperature = chosen_temperature
            logger.info("Vertex model loaded successfully.")
            return True
        except Exception as exception:
            logger.error(f"Error loading Vertex model: {str(exception)}")
            logger.error(traceback.format_exc())  # Add traceback details
            return False

    @classmethod
    def _collect_guidelines(cls, enabled_flags):
        """Return the guideline lines whose flag is truthy in *enabled_flags*."""
        selected = []
        for key, label, guideline in cls._GUIDELINE_RULES:
            if enabled_flags[key]:
                logger.info(f"{label} is enabled.")
                selected.append(guideline)
        return selected

    @classmethod
    def _extract_code(cls, response):
        """Return the code inside the fenced block of *response*.

        The block may sit anywhere in the text. A missing closing fence (for
        example a response cut off by the token limit) is tolerated, and the
        first line is dropped only when it looks like a language tag. Returns
        None when *response* contains no fence at all.
        """
        fence = cls._CODE_FENCE
        opening = response.find(fence)
        if opening == -1:
            return None

        fenced = response[opening + len(fence):]
        closing = fenced.rfind(fence)
        if closing != -1:
            fenced = fenced[:closing]

        first_line, newline, remainder = fenced.partition("\n")
        if newline and cls._LANGUAGE_TAG.fullmatch(first_line.strip()):
            return remainder
        # Single-line block, or code that starts right after the fence.
        return fenced

    def generate_code(self, code_prompt, code_language):
        """Generate a program for *code_prompt* in *code_language*.

        The enabled coding guidelines and the optional program input are read
        from ``st.session_state``.

        Returns:
            The extracted code; the raw model response when no usable fenced
            block is found; or None when the prompt is empty, no model is
            loaded, or an exception occurred (logged and shown as a toast).
        """
        try:
            logger.info(f"Generating code with parameters: {code_prompt}, {code_language}")

            if not code_prompt:
                logger.error("Code prompt is empty or null.")
                st.toast("Code prompt is empty or null.", icon="❌")
                return None

            if self.vertexai_llm is None:
                logger.error("Vertex model is not loaded; call load_model first.")
                st.toast("Vertex model is not loaded; call load_model first.", icon="❌")
                return None

            # Dynamically construct guidelines based on session state
            guidelines_list = self._collect_guidelines(st.session_state["coding_guidelines"])
            logger.info("Guidelines: " + str(guidelines_list))
            guidelines = "\n".join(guidelines_list)

            input_section = f"Given the input for code: {st.session_state.code_input}" if st.session_state.code_input else "make sure the program doesn't ask for any input from the user"

            # input_section and guidelines are template variables rather than
            # being pasted into the template text: user input containing
            # "{" or "}" would otherwise be parsed as template placeholders.
            template = """
            Task: Design a program {code_prompt} in {code_language} with the following guidelines and
            make sure the output is printed on the screen.
            And make sure the output contains only the code and nothing else.
            {input_section}

            Guidelines:
            {guidelines}
            """
            chain_inputs = {
                "code_prompt": code_prompt,
                "code_language": code_language,
                "input_section": input_section,
                "guidelines": guidelines,
            }

            prompt = PromptTemplate(template=template, input_variables=list(chain_inputs))
            formatted_prompt = prompt.format(**chain_inputs)
            logger.info(f"Formatted prompt: {formatted_prompt}")

            logger.info("Setting up LLMChain...")
            llm_chain = LLMChain(prompt=prompt, llm=self.vertexai_llm)
            logger.info("LLMChain setup successfully.")

            # Pass the required inputs as a dictionary to the chain
            logger.info("Running LLMChain...")
            response = llm_chain.run(chain_inputs)
            if not response:
                logger.warning("No response received from LLMChain.")
                return response
            logger.info(f"Code generated successfully: {response}")

            generated_code = self._extract_code(response)
            if generated_code is None:
                # No fenced block at all: hand back the raw response.
                st.toast(f"Error extracting code", icon="❌")
                return response

            if not generated_code.strip():
                logger.error(f"Error generating code: {response}")
                st.toast(f"Error generating code: {response}", icon="❌")
                return response

            logger.info(f"Code generated successfully: {generated_code}")
            return generated_code
        except Exception as exception:
            stack_trace = traceback.format_exc()
            logger.error(f"Error generating code: {str(exception)} stack trace: {stack_trace}")
            st.toast(f"Error generating code: {str(exception)} stack trace: {stack_trace}", icon="❌")
            return None

    def generate_code_completion(self, code_prompt, code_language):
        """Complete *code_prompt* with the code-completion model.

        Token limit and temperature are read from
        ``st.session_state["vertexai"]``; the token limit is capped for the
        completion model. The stored generation settings (``self.model_name``
        etc.) are left untouched so later ``set_*`` reloads are not affected.

        Returns:
            The completion text, or None when the prompt is empty or the
            chain produced no response.

        Raises:
            Exception: any error from building or running the chain is logged
                and re-raised.
        """
        try:
            if not code_prompt:
                logger.error("Code prompt is empty or null.")
                st.error("Code generateration cannot be performed as the code prompt is empty or null.")
                return None

            logger.info(f"Generating code completion with parameters: {code_prompt}, {code_language}")
            template = "Complete the following {code_language} code: {code_prompt}"
            prompt_obj = PromptTemplate(template=template, input_variables=["code_language", "code_prompt"])

            vertex_settings = st.session_state["vertexai"]
            max_tokens = vertex_settings["max_tokens"]
            temperature = vertex_settings["temperature"]

            # Check the maximum number of tokens of Gecko model
            token_cap = self._COMPLETION_MAX_TOKENS
            if max_tokens > token_cap:
                max_tokens = token_cap
                logger.info(f"Maximum number of tokens for Model Gecko can't exceed {token_cap}. Setting max_tokens to {token_cap}.")
                st.toast(f"Maximum number of tokens for Model Gecko can't exceed {token_cap}. Setting max_tokens to {token_cap}.", icon="⚠️")

            completion_model = self._COMPLETION_MODEL
            llm_options = {
                "model_name": completion_model,
                "max_output_tokens": max_tokens,
                "temperature": temperature,
            }
            # Use the same location and service-account credentials as
            # load_model when they are configured.
            if self.location:
                llm_options["location"] = self.location
            if self.credentials_file_path:
                llm_options["credentials"] = service_account.Credentials.from_service_account_file(self.credentials_file_path)

            self.llm = VertexAI(**llm_options)
            logger.info(f"Initialized VertexAI with model: {completion_model}")
            llm_chain = LLMChain(prompt=prompt_obj, llm=self.llm)
            response = llm_chain.run({"code_prompt": code_prompt, "code_language": code_language})

            if response:
                logger.info(f"Code completion generated successfully: {response}")
                return response

            logger.warning("No response received from LLMChain.")
            return None
        except Exception as e:
            logger.error(f"Error generating code completion: {str(e)}")
            raise

    def _update_loaded_model(self, attribute, value):
        """Set *attribute* on the loaded LLM and rebuild it; no-op when none is loaded."""
        if self.vertexai_llm is None:
            # Nothing to reload yet; the stored setting is enough.
            return
        setattr(self.vertexai_llm, attribute, value)
        self.load_model(self.model_name, self.max_tokens, self.temperature)

    def set_temperature(self, temperature):
        """Store a new temperature and, if a model is loaded, reload it with that value."""
        self.temperature = temperature
        self._update_loaded_model("temperature", temperature)

    def set_max_tokens(self, max_tokens):
        """Store a new output-token limit and, if a model is loaded, reload it with that value."""
        self.max_tokens = max_tokens
        self._update_loaded_model("max_output_tokens", max_tokens)

    def set_model_name(self, model_name):
        """Store a new model name and load that model with the other stored settings."""
        self.model_name = model_name
        self.load_model(self.model_name, self.max_tokens, self.temperature)