| import unittest |
|
|
| import matplotlib |
|
|
| matplotlib.use("Agg") |
|
|
| from quread.heatmap import HeatmapConfig, make_activity_heatmap |
|
|
|
|
class HeatmapTest(unittest.TestCase):
    """Tests for ``quread.heatmap.make_activity_heatmap``."""

    def test_malformed_rows_are_skipped_without_crashing(self):
        """Malformed CSV rows are skipped and reported instead of raising.

        The CSV below has two well-formed rows (an ``H`` on qubit 0 and a
        ``CNOT`` with target 1 / control 0) and two malformed rows: one whose
        ``target`` is not an integer and one with an empty ``gate``. The
        heatmap must still be produced, reflect only the valid rows, and carry
        a text note reporting how many rows were skipped.
        """
        # Local import: the module-level ``matplotlib.use("Agg")`` has already
        # selected the non-interactive backend by the time this runs.
        from matplotlib import pyplot as plt

        csv_text = "\n".join(
            [
                "step,gate,target,control,theta",
                "0,H,0,,",
                "1,CNOT,1,0,",
                "2,H,not_an_int,,",  # malformed: non-integer target
                "3,,1,,",  # malformed: empty gate name
            ]
        )

        fig = make_activity_heatmap(
            csv_text, n_qubits=2, cfg=HeatmapConfig(rows=2, cols=2)
        )
        # Close the figure when the test ends so open figures do not pile up
        # across tests; plt.close is a no-op for figures pyplot does not track.
        self.addCleanup(plt.close, fig)

        ax = fig.axes[0]
        grid = ax.images[0].get_array()
        labels = [t.get_text() for t in ax.texts]

        # Presumably qubit 0 is counted as the H target and as the CNOT
        # control (2), and qubit 1 as the CNOT target (1). The 2x2 layout is
        # assumed to place qubit i at row 0, column i -- confirm against
        # HeatmapConfig if the layout changes.
        self.assertEqual(float(grid[0, 0]), 2.0)
        self.assertEqual(float(grid[0, 1]), 1.0)
        self.assertIn("Skipped 2 malformed CSV row(s)", labels)
|
|
|
|
if __name__ == "__main__":
    # Executing this file directly runs every TestCase defined in this module.
    unittest.main(module="__main__")
|
|