# NOTE(review): this file was captured from a Hugging Face Space page whose
# build was failing at capture time; the surrounding page chrome has been
# removed. The code below is the Space's application source.
import logging
import re
from dataclasses import dataclass

import numpy as np
import spacy
import spacy.cli
import spacy.lang.en
from joblib import Parallel, delayed
from lambeq import AtomicType, BobcatParser, IQPAnsatz, Rewriter
from pytket.extensions.qiskit import tk_to_qiskit
from qiskit import QuantumCircuit
from qiskit.primitives.primitive_job import PrimitiveJob
from qiskit_aer.primitives import SamplerV2
from tqdm import tqdm
# ISO 639-1 language code -> name of the spaCy pipeline used for tokenising.
LANGUAGE_2_SPACY = {
    'en': 'en_core_web_sm',
    'zh': 'zh_core_web_sm',
}

# Number of qubits assigned to each lambeq atomic type when compiling a
# string diagram to a circuit; only the sentence wire carries a qubit.
MAPPING = {
    AtomicType.PREPOSITIONAL_PHRASE: 0,
    AtomicType.NOUN: 0,
    AtomicType.SENTENCE: 1,
    AtomicType.CONJUNCTION: 0,
    AtomicType.PUNCTUATION: 0,
}

# Shared single-layer IQP ansatz; discards (rather than postselects) unused wires.
ANSATZ = IQPAnsatz(MAPPING, n_layers=1, discard=True)

LOGGER = logging.getLogger(__name__)
@dataclass
class QNLP_OUTPUT:
    """Result bundle for one processed sentence.

    Bug fix: the original class declared bare annotations with no generated
    ``__init__``, so ``QNLP_OUTPUT(tokens, circuit, job)`` (as called from
    ``QNLP.process_sentence``) raised ``TypeError``. ``dataclass`` was already
    imported but never applied — it is the intended constructor mechanism.
    """

    # Tokenised sentence as produced by the spaCy pipeline. The caller passes
    # a list of token strings (the original annotation said ``str``).
    tokens: list[str]
    # Parameterised, measured Qiskit circuit compiled from the diagram.
    circuit: QuantumCircuit
    # Sampler job running the circuit; results are read lazily via array().
    job: PrimitiveJob

    def array(self):
        """Return the measurement array, or an empty array while the job runs."""
        if not self.job.done():
            return np.array([])
        return self.job.result()[0].data.meas.array

    def valid(self):
        """True once the sampler job has finished and array() is meaningful."""
        return self.job.done()
class QNLP:
    """Text-to-quantum-circuit pipeline.

    Splits raw text into sentences with spaCy, parses each sentence into a
    DisCoCat diagram with lambeq's BobcatParser, rewrites and normalises it,
    compiles it with the shared IQP ansatz to a Qiskit circuit, and submits
    one sampler job per sentence.
    """

    def __init__(self, langauge: str = "en") -> None:
        """Load the tokeniser, rewriter and parser for the given language.

        NOTE(review): the parameter name is a typo for ``language`` but is kept
        to stay backward-compatible with keyword callers.

        Raises:
            ValueError: if the language has no registered spaCy pipeline.
        """
        model = LANGUAGE_2_SPACY.get(langauge)
        # Bug fix: an unknown code previously produced model=None and crashed
        # inside spacy.load with an opaque error; fail fast and clearly instead.
        if model is None:
            raise ValueError(
                f"Unsupported language {langauge!r}; "
                f"expected one of {sorted(LANGUAGE_2_SPACY)}"
            )
        try:
            self.nlp = spacy.load(model)
        except OSError:
            # Pipeline not installed yet — fetch it once, then retry.
            LOGGER.warning('Downloading SpaCy tokeniser. '
                           'This action only has to happen once.')
            spacy.cli.download(model)
            self.nlp = spacy.load(model)
        self.rewriter = Rewriter([
            'auxiliary',
            'connector',
            'determiner',
            'postadverb',
            'preadverb',
            'prepositional_phrase',
        ])
        self.parser = BobcatParser('bert', cache_dir='./')

    def process_sentence(self,
                         input_sentence: str,
                         shots: int = 1024,
                         ) -> list[QNLP_OUTPUT]:
        """Compile every sentence of *input_sentence* and launch a sampler job.

        Args:
            input_sentence: Raw text; newlines are stripped before sentencising.
            shots: Number of measurement shots per circuit.

        Returns:
            One QNLP_OUTPUT per detected sentence, each holding the token list,
            the measured circuit, and its (possibly still-running) sampler job.
        """
        input_sentence = re.sub(r'\n+', '', input_sentence)
        doc = self.nlp(input_sentence)
        # One list of token strings per spaCy-detected sentence.
        sentences = [[str(tok).strip() for tok in sent] for sent in doc.sents]

        def sentence2diagram(tokens: list[str]):
            # Parse pre-tokenised input, then apply the rewrite rules and
            # normalise the resulting diagram.
            diagram = self.parser.sentence2diagram(tokens, tokenised=True)
            return self.rewriter(diagram).normal_form()

        diagrams = [sentence2diagram(tokens) for tokens in sentences]
        qiskit_circuits = [tk_to_qiskit(ANSATZ(d).to_tk()) for d in diagrams]
        for qc in qiskit_circuits:
            qc.measure_all()
        sampler = SamplerV2()
        # All ansatz parameters are bound to 1.0 — presumably a placeholder
        # until trained parameters are available; TODO confirm.
        jobs = [sampler.run([(qc, [1] * qc.num_parameters)], shots=shots)
                for qc in qiskit_circuits]
        return [QNLP_OUTPUT(tokens, circuit, job)
                for tokens, circuit, job in zip(sentences, qiskit_circuits, jobs)]