Update src/pipeline.py
Browse files- src/pipeline.py +5 -2
src/pipeline.py
CHANGED
|
@@ -23,6 +23,7 @@ def error_handler(func: Callable):
|
|
| 23 |
return func(*args, **kwargs)
|
| 24 |
except Exception as e:
|
| 25 |
print(f"Error in {func.__name__}: {str(e)}")
|
|
|
|
| 26 |
return wrapper
|
| 27 |
|
| 28 |
class TorchOptimizer:
|
|
@@ -107,8 +108,10 @@ class PipelineManager:
|
|
| 107 |
pipe.to("cuda")
|
| 108 |
|
| 109 |
# Optimize pipeline
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
| 112 |
# Trigger compilation
|
| 113 |
print("Running torch compilation...")
|
| 114 |
pipe(
|
|
|
|
| 23 |
return func(*args, **kwargs)
|
| 24 |
except Exception as e:
|
| 25 |
print(f"Error in {func.__name__}: {str(e)}")
|
| 26 |
+
return None
|
| 27 |
return wrapper
|
| 28 |
|
| 29 |
class TorchOptimizer:
|
|
|
|
| 108 |
pipe.to("cuda")
|
| 109 |
|
| 110 |
# Optimize pipeline
|
| 111 |
+
pipe_ops = self.optimize_pipeline(pipe)
|
| 112 |
+
if pipe_ops!=None:
|
| 113 |
+
pipe = pipe_ops
|
| 114 |
+
|
| 115 |
# Trigger compilation
|
| 116 |
print("Running torch compilation...")
|
| 117 |
pipe(
|