Spaces:
Sleeping
Sleeping
distilgpt2 model pipeline..
Browse files
app.py
CHANGED
|
@@ -1,35 +1,47 @@
|
|
| 1 |
from transformers import pipeline
|
|
|
|
| 2 |
|
| 3 |
-
# Initialize the model
|
| 4 |
-
print("Initializing model pipeline...")
|
| 5 |
try:
|
| 6 |
model = pipeline(
|
| 7 |
"text-generation",
|
| 8 |
-
model="
|
| 9 |
-
torch_dtype=torch.float32 # Use float32
|
| 10 |
)
|
| 11 |
-
print("
|
| 12 |
except Exception as e:
|
| 13 |
print(f"Error initializing model: {e}")
|
| 14 |
raise
|
| 15 |
|
| 16 |
def generate_test_cases(requirement):
|
| 17 |
-
#
|
| 18 |
prompt = f"Generate test cases for the following software requirement in JSON format: '{requirement}'. Only provide the JSON array of test cases."
|
| 19 |
|
| 20 |
try:
|
| 21 |
print("Generating test cases...")
|
|
|
|
| 22 |
result = model(prompt, max_length=300, num_return_sequences=1)[0]["generated_text"]
|
| 23 |
print("Test cases generated successfully.")
|
| 24 |
return result.strip()
|
| 25 |
except Exception as e:
|
| 26 |
print(f"Error during generation: {e}")
|
| 27 |
-
|
| 28 |
|
| 29 |
# Example usage
|
| 30 |
if __name__ == "__main__":
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from transformers import pipeline
|
| 2 |
+
import torch
|
| 3 |
|
| 4 |
+
# Initialize the smaller GPT-2 model (distilgpt2)
|
| 5 |
+
print("Initializing distilgpt2 model pipeline...")
|
| 6 |
try:
|
| 7 |
model = pipeline(
|
| 8 |
"text-generation",
|
| 9 |
+
model="distilgpt2", # Use the smaller, efficient distilgpt2 model
|
| 10 |
+
torch_dtype=torch.float32 # Use float32 for CPU compatibility
|
| 11 |
)
|
| 12 |
+
print("distilgpt2 model pipeline initialized successfully.")
|
| 13 |
except Exception as e:
|
| 14 |
print(f"Error initializing model: {e}")
|
| 15 |
raise
|
| 16 |
|
| 17 |
def generate_test_cases(requirement):
|
| 18 |
+
# Formulate the prompt to ask the model to generate test cases in JSON format
|
| 19 |
prompt = f"Generate test cases for the following software requirement in JSON format: '{requirement}'. Only provide the JSON array of test cases."
|
| 20 |
|
| 21 |
try:
|
| 22 |
print("Generating test cases...")
|
| 23 |
+
# Generate text (test cases) with a max length of 300 characters
|
| 24 |
result = model(prompt, max_length=300, num_return_sequences=1)[0]["generated_text"]
|
| 25 |
print("Test cases generated successfully.")
|
| 26 |
return result.strip()
|
| 27 |
except Exception as e:
|
| 28 |
print(f"Error during generation: {e}")
|
| 29 |
+
return None
|
| 30 |
|
| 31 |
# Example usage
|
| 32 |
if __name__ == "__main__":
|
| 33 |
+
# Sample requirements for testing
|
| 34 |
+
requirements = [
|
| 35 |
+
"User login functionality with email and password",
|
| 36 |
+
"Search functionality on an e-commerce website",
|
| 37 |
+
"Order placement process for an online store",
|
| 38 |
+
"Admin dashboard to manage users and content"
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
for req in requirements:
|
| 42 |
+
print(f"\nTest cases for requirement: {req}")
|
| 43 |
+
test_cases = generate_test_cases(req)
|
| 44 |
+
if test_cases:
|
| 45 |
+
print(test_cases)
|
| 46 |
+
else:
|
| 47 |
+
print("Failed to generate test cases.")
|