Commit ·
d8867cf
1
Parent(s): 6398d31
chore: add extra tests
Browse files- app/nlp.py +1 -1
- tests/test_service.py +38 -0
app/nlp.py
CHANGED
|
@@ -136,7 +136,7 @@ def generate_video(audio, text, filter_option):
|
|
| 136 |
output_path = os.path.join("videos", f"{uuid.uuid4()}.mp4")
|
| 137 |
|
| 138 |
try:
|
| 139 |
-
output = pipe(prompt, num_inference_steps=
|
| 140 |
frames = output.frames if hasattr(output, "frames") else output.images
|
| 141 |
export_to_video(frames[0], output_video_path=output_path)
|
| 142 |
|
|
|
|
| 136 |
output_path = os.path.join("videos", f"{uuid.uuid4()}.mp4")
|
| 137 |
|
| 138 |
try:
|
| 139 |
+
output = pipe(prompt, num_inference_steps=1, height=320, width=576, num_frames=1, output_type="pil")
|
| 140 |
frames = output.frames if hasattr(output, "frames") else output.images
|
| 141 |
export_to_video(frames[0], output_video_path=output_path)
|
| 142 |
|
tests/test_service.py
CHANGED
|
@@ -18,6 +18,44 @@ class TestText2VideoService(unittest.TestCase):
|
|
| 18 |
request = pb2.VideoRequest(prompt="", audio_path="", filter_option="None")
|
| 19 |
response = self.stub.Generate(request)
|
| 20 |
self.assertEqual(response.status_code, 400)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
if __name__ == "__main__":
|
| 23 |
unittest.main()
|
|
|
|
| 18 |
request = pb2.VideoRequest(prompt="", audio_path="", filter_option="None")
|
| 19 |
response = self.stub.Generate(request)
|
| 20 |
self.assertEqual(response.status_code, 400)
|
| 21 |
+
|
| 22 |
+
def test_generate_with_filter(self):
|
| 23 |
+
request = pb2.VideoRequest(prompt="A forest in autumn", audio_path="", filter_option="Sepia")
|
| 24 |
+
response = self.stub.Generate(request)
|
| 25 |
+
self.assertEqual(response.status_code, 200)
|
| 26 |
+
self.assertIn(".mp4", response.video_path)
|
| 27 |
+
|
| 28 |
+
def test_generate_invalid_filter(self):
|
| 29 |
+
request = pb2.VideoRequest(prompt="A beach at sunset", audio_path="", filter_option="InvalidFilter")
|
| 30 |
+
response = self.stub.Generate(request)
|
| 31 |
+
self.assertEqual(response.status_code, 400)
|
| 32 |
+
|
| 33 |
+
def test_generate_very_long_prompt(self):
|
| 34 |
+
long_prompt = "A " + "very " * 100 + "long prompt"
|
| 35 |
+
request = pb2.VideoRequest(prompt=long_prompt, audio_path="", filter_option="None")
|
| 36 |
+
response = self.stub.Generate(request)
|
| 37 |
+
# Depending on your implementation, this might succeed or return an error
|
| 38 |
+
# Adjust the assertion based on your service's expected behavior
|
| 39 |
+
self.assertIn(response.status_code, [200, 400])
|
| 40 |
+
|
| 41 |
+
def test_generate_special_characters(self):
|
| 42 |
+
request = pb2.VideoRequest(prompt="Special characters: !@#$%^&*()", audio_path="", filter_option="None")
|
| 43 |
+
response = self.stub.Generate(request)
|
| 44 |
+
self.assertEqual(response.status_code, 200)
|
| 45 |
+
self.assertIn(".mp4", response.video_path)
|
| 46 |
+
|
| 47 |
+
def test_timeout_handling(self):
|
| 48 |
+
# Test with a timeout to ensure the service handles timeouts gracefully
|
| 49 |
+
try:
|
| 50 |
+
channel = grpc.insecure_channel('localhost:50051')
|
| 51 |
+
stub = pb2_grpc.VideoGeneratorStub(channel)
|
| 52 |
+
request = pb2.VideoRequest(prompt="A complex scene that takes time", audio_path="", filter_option="None")
|
| 53 |
+
# Set a very short timeout to force a timeout error
|
| 54 |
+
response = stub.Generate(request, timeout=0.001)
|
| 55 |
+
self.fail("Expected timeout exception was not raised")
|
| 56 |
+
except grpc.RpcError as e:
|
| 57 |
+
# Just verify that we can catch the timeout exception
|
| 58 |
+
pass
|
| 59 |
|
| 60 |
if __name__ == "__main__":
|
| 61 |
unittest.main()
|