File size: 1,047 Bytes
00b5849
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import streamlit as st
from transformers import pipeline, AutoTokenizer
import time

# Page header rendered at the top of the Streamlit app.
st.title("Ask Licorn 🦄")

# Display label (shown in the model selector, in this order) -> model
# identifier handed to `AutoTokenizer.from_pretrained` / `pipeline`.
MODELS = dict(
    DistilBERT="aharkane/squad-distilbert-v2",
    ALBERT="aharkane/squad-albert-v2",
    MobileBERT="aharkane/squad-mobilebert-v2",
)

@st.cache_resource
def load_model(model_id):
    """Build a question-answering pipeline for ``model_id``.

    Wrapped in ``st.cache_resource``, so each distinct ``model_id`` is loaded
    once and the same pipeline object is reused on later Streamlit reruns.

    Args:
        model_id: Model identifier passed to ``AutoTokenizer.from_pretrained``
            and to ``pipeline(model=...)``.

    Returns:
        A ``transformers`` "question-answering" pipeline using that model and
        its (possibly adjusted) tokenizer.
    """
    tok = AutoTokenizer.from_pretrained(model_id)

    # DistilBERT models take no `token_type_ids`, so restrict the tokenizer to
    # the two inputs they do accept. Detection is a case-insensitive substring
    # match on the id.
    needs_trimmed_inputs = "distilbert" in model_id.lower()
    if needs_trimmed_inputs:
        tok.model_input_names = ["input_ids", "attention_mask"]

    qa_pipeline = pipeline("question-answering", model=model_id, tokenizer=tok)
    return qa_pipeline

# --- UI: model picker, inputs, and answer display ---------------------------
model_choice = st.selectbox("Choisir un modèle", list(MODELS.keys()))
context = st.text_area("Contexte", height=150)
question = st.text_input("Question")

if st.button("Répondre"):
    if not context.strip() or not question.strip():
        # The QA pipeline raises ValueError on an empty question/context;
        # show a friendly message instead of a stack trace.
        st.warning("Veuillez saisir un contexte et une question.")
    else:
        # Load (or fetch the cached) pipeline before starting the timer so the
        # reported latency covers inference only.
        qa = load_model(MODELS[model_choice])

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump if the system clock is adjusted.
        start = time.perf_counter()
        result = qa(question=question, context=context)
        duration = time.perf_counter() - start

        st.success(f"**{result['answer']}**")
        st.write(f"Confiance : {result['score']:.2%}")
        st.write(f"Temps : {duration*1000:.0f} ms")