Spaces:
Runtime error
Runtime error
Commit
·
10c8213
1
Parent(s):
add4b49
07/01/23-15:27
Browse files
app.py
CHANGED
|
@@ -4,7 +4,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
| 4 |
|
| 5 |
|
| 6 |
|
| 7 |
-
def extract_code(input_text):
    """Return the Python snippet fenced as '''py ... ''' inside *input_text*.

    The model's completion is expected to wrap generated code in a
    triple-single-quote fence opened by ``'''py`` and a newline.  DOTALL lets
    the capture span multiple lines.  Returns the captured code, or None when
    no fence is present.
    """
    fence = re.search(r"'''py\n(.*?)'''", input_text, re.DOTALL)
    return fence.group(1) if fence else None  # None signals "no code found"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
def generate_code(input_text, modelName):
    """Generate text for *input_text* with the model selected by *modelName*.

    Supported model names are "codegen-350M" and "mistral-7b"; the matching
    tokenizer/model pair (module-level globals ``codeGenTokenizer``/
    ``codeGenModel`` and ``mistralTokenizer``/``mistralModel`` — defined
    elsewhere in this file) is used to encode, generate (max_length=128) and
    decode.  Returns the decoded string, or None for an unknown model name.
    """
    if modelName == "codegen-350M":
        input_ids = codeGenTokenizer(input_text, return_tensors="pt").input_ids
        generated_ids = codeGenModel.generate(input_ids, max_length=128)
        result = codeGenTokenizer.decode(generated_ids[0], skip_special_tokens=True)
        # BUG FIX: the original branch ended in a bare `return`, silently
        # discarding the decoded text; return it like the mistral branch does.
        return result
    elif modelName == "mistral-7b":
        # NOTE(review): assumes generate_prompt_mistral wraps the raw input in
        # the model's expected prompt template — defined elsewhere in this file.
        input_ids = mistralTokenizer(generate_prompt_mistral(input_text), return_tensors="pt").input_ids
        generated_ids = mistralModel.generate(input_ids, max_length=128)
        result = mistralTokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return result
    else:
        return None
|
| 29 |
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
|
| 7 |
+
def extract_code_codegen(input_text):
    """Extract the '''py ... ''' fenced snippet from a codegen-350M completion.

    Searches *input_text* for a block opened by ``'''py`` plus a newline and
    closed by ``'''``; re.DOTALL allows the snippet to span multiple lines.
    Returns the snippet text, or None when no fence is found.
    """
    found = re.search(r"'''py\n(.*?)'''", input_text, re.DOTALL)
    if not found:
        return None  # no fenced code in the completion
    return found.group(1)
|
| 15 |
+
|
| 16 |
+
def extract_code_mistral(input_text):
    """Extract the [CODE] ... [/CODE] span from a mistral-7b completion.

    The mistral prompt template marks generated code with literal
    ``[CODE]``/``[/CODE]`` tags; re.DOTALL lets the captured span cover
    multiple lines.  Returns the tagged text, or None when no tags match.
    """
    tagged = re.search(r'\[CODE\](.*?)\[/CODE\]', input_text, re.DOTALL)
    return tagged.group(1) if tagged else None  # None signals "no code found"
|
| 24 |
|
| 25 |
def generate_code(input_text, modelName):
    """Generate and extract code for *input_text* with the selected model.

    Dispatches on *modelName*: "codegen-350M" or "mistral-7b".  The matching
    module-level tokenizer/model pair (defined elsewhere in this file) encodes
    the prompt, generates up to 128 tokens, decodes the output, and the
    model-specific extractor pulls the code span out of the completion.
    Returns the extracted code (or None if extraction fails), or None for an
    unknown model name.
    """
    if modelName == "codegen-350M":
        tokens = codeGenTokenizer(input_text, return_tensors="pt").input_ids
        output = codeGenModel.generate(tokens, max_length=128)
        decoded = codeGenTokenizer.decode(output[0], skip_special_tokens=True)
        return extract_code_codegen(decoded)

    if modelName == "mistral-7b":
        prompt = generate_prompt_mistral(input_text)
        tokens = mistralTokenizer(prompt, return_tensors="pt").input_ids
        output = mistralModel.generate(tokens, max_length=128)
        decoded = mistralTokenizer.decode(output[0], skip_special_tokens=True)
        return extract_code_mistral(decoded)

    return None  # unrecognized model name
|
| 38 |
|