# QNLPDemoApp / utils.py
# Yapp99's picture
# Another bandaid fix
# faf1fab
import numpy as np
from qiskit.primitives.primitive_job import PrimitiveJob
from qiskit import QuantumCircuit
import re
import spacy
import spacy.lang.en
import spacy.cli
import logging
from lambeq import Rewriter, BobcatParser
from tqdm import tqdm
from lambeq import AtomicType, IQPAnsatz
from pytket.extensions.qiskit import tk_to_qiskit
from qiskit_aer.primitives import SamplerV2
from dataclasses import dataclass
from joblib import Parallel, delayed
# Language code -> name of the spaCy pipeline package loaded for that language.
LANGUAGE_2_SPACY = {
    'en': 'en_core_web_sm',
    'zh': 'zh_core_web_sm',
}
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# Object map handed to the ansatz: the sentence type maps to 1, every other
# atomic type listed here maps to 0.
MAPPING = {
    atomic_type: 1 if atomic_type is AtomicType.SENTENCE else 0
    for atomic_type in (
        AtomicType.PREPOSITIONAL_PHRASE,
        AtomicType.NOUN,
        AtomicType.SENTENCE,
        AtomicType.CONJUNCTION,
        AtomicType.PUNCTUATION,
    )
}

# Single shared IQP ansatz (one layer, discard enabled) used to turn diagrams
# into circuits.
ANSATZ = IQPAnsatz(MAPPING, n_layers=1, discard=True)
@dataclass
class QNLP_OUTPUT():
    """Per-sentence result produced by ``QNLP.process_sentence``.

    Bundles the tokens of one sentence with the circuit built for it and the
    sampler job that runs that circuit.
    """

    # Tokens of the sentence, one string per token (a list, not a single
    # string: ``process_sentence`` passes the tokenised sentence here).
    tokens: list[str]
    # Qiskit circuit generated for the sentence.
    circuit: QuantumCircuit
    # Sampler job executing ``circuit``.
    job: PrimitiveJob

    @property
    def valid(self) -> bool:
        """Return whether the sampler job reports itself as done."""
        return self.job.done()

    @property
    def array(self) -> np.ndarray:
        """Return the sampled measurement data, or an empty array.

        Returns:
            The ``array`` stored for the ``meas`` register of the first
            result when the job is done; otherwise an empty NumPy array, so
            that reading this property never blocks on an unfinished job.
        """
        if not self.valid:
            return np.array([])
        return self.job.result()[0].data.meas.array
class QNLP():
    """Turn raw text into sampled quantum circuits, one per sentence.

    The text is split into tokenised sentences with spaCy, each sentence is
    parsed into a diagram with Bobcat and simplified by a set of rewrite
    rules, the diagram is converted to a Qiskit circuit through the shared
    ``ANSATZ``, and every circuit is submitted to an Aer ``SamplerV2``.
    """

    def __init__(self, langauge = "en") -> None:
        """Load the spaCy pipeline, the rewriter and the Bobcat parser.

        Args:
            langauge: Language code; must be a key of ``LANGUAGE_2_SPACY``.
                The misspelt parameter name is kept so that existing
                keyword callers keep working.

        Raises:
            ValueError: If ``langauge`` is not a supported language code.
        """
        model = LANGUAGE_2_SPACY.get(langauge)
        if model is None:
            # Fail early: otherwise ``None`` would be handed to spaCy and the
            # fallback below would try to download a model named "None".
            supported = ', '.join(sorted(LANGUAGE_2_SPACY))
            raise ValueError(
                f'Unsupported language {langauge!r}; '
                f'expected one of: {supported}'
            )

        try:
            self.nlp = spacy.load(model)
        except OSError:
            # The pipeline package is not installed yet: fetch it and retry.
            LOGGER.warning('Downloading SpaCy tokeniser. '
                           'This action only has to happen once.')
            spacy.cli.download(model)
            self.nlp = spacy.load(model)

        self.rewriter = Rewriter([
            'auxiliary',
            'connector',
            'determiner',
            'postadverb',
            'preadverb',
            'prepositional_phrase',
        ])
        self.parser = BobcatParser('bert', cache_dir='./')

    def _sentence2diagram(self, tokens: list[str]):
        """Parse one tokenised sentence and return its rewritten normal form."""
        diagram = self.parser.sentence2diagram(tokens, tokenised=True)
        return self.rewriter(diagram).normal_form()

    def process_sentence(self,
                         input_sentence: str,
                         shots: int = 1024
                         ) -> list[QNLP_OUTPUT]:
        """Build and sample one circuit per sentence found in the input.

        Args:
            input_sentence: Raw text; may contain several sentences and
                line breaks.
            shots: Number of shots requested from the sampler per circuit.

        Returns:
            One ``QNLP_OUTPUT`` per non-empty sentence, in document order.
            An input without any tokens yields an empty list.
        """
        # Line breaks become a single space so that the words on either side
        # of a break are not glued together into one token.
        text = re.sub(r'\n+', ' ', input_sentence)
        doc = self.nlp(text)

        sentences = []
        for span in doc.sents:
            stripped = (str(token).strip() for token in span)
            # Whitespace-only tokens strip down to '' and must not reach the
            # parser; a sentence left without tokens is skipped entirely.
            tokens = [word for word in stripped if word]
            if tokens:
                sentences.append(tokens)

        diagrams = [self._sentence2diagram(tokens) for tokens in sentences]
        qiskit_circuits = [tk_to_qiskit(ANSATZ(diagram).to_tk())
                           for diagram in diagrams]
        for qc in qiskit_circuits:
            qc.measure_all()

        sampler = SamplerV2()
        # Every circuit parameter is bound to the constant value 1.
        jobs = [sampler.run([(qc, [1] * qc.num_parameters)], shots=shots)
                for qc in qiskit_circuits]

        return [QNLP_OUTPUT(tokens, qc, job)
                for tokens, qc, job in zip(sentences, qiskit_circuits, jobs)]