File size: 707 Bytes
5b6f681
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import pytest

from src import main


class DummyPipeline:
    """Callable stub used in place of a real pipeline object in these tests.

    Calling an instance with a text returns a one-element list holding a
    fixed prediction dict that also carries the text it was called with.
    """

    def __call__(self, text):
        # Fixed label/score; only the "text" entry depends on the input.
        canned_prediction = {"label": "POSITIVE", "score": 0.99, "text": text}
        return [canned_prediction]


def test_predict_happy_path(monkeypatch):
    """Check that predict() echoes text/model/task and returns a list result."""

    def fake_pipeline(task, model=None):
        # Same call shape as the replaced factory; always hands back the stub.
        return DummyPipeline()

    # Mock the transformers.pipeline constructor as seen from the main module.
    monkeypatch.setattr(main, "pipeline", fake_pipeline)

    input_text = "Hello world"
    chosen_model = "dummy-model"
    chosen_task = "sentiment-analysis"

    response = main.predict(input_text, model_name=chosen_model, task=chosen_task)

    # The inputs are expected to be reflected back under these keys.
    assert response["text"] == input_text
    assert response["model"] == chosen_model
    assert response["task"] == chosen_task
    # Only the container type of the pipeline output is checked here.
    assert isinstance(response["result"], list)


def test_predict_type_error():
    """Check that predict() raises TypeError when handed an int as its first argument."""
    bad_input = 123
    with pytest.raises(TypeError):
        # Deliberately wrong argument type, hence the type-checker suppression.
        main.predict(bad_input)  # type: ignore