Spaces:
Configuration error
Configuration error
| from twisted.web.static import File | |
| from twisted.web.resource import Resource | |
| from twisted.web.server import Site, NOT_DONE_YET | |
| from twisted.internet import reactor, threads | |
| from twisted.web._responses import FOUND | |
| import json | |
| import logging | |
| import multiprocessing | |
| import os | |
| import shutil | |
| import uuid | |
| import wave | |
| from gentle.util.paths import get_resource, get_datadir | |
| from gentle.util.cyst import Insist | |
| import gentle | |
class TranscriptionStatus(Resource):
    """Resource that renders a status dictionary as JSON on GET.

    The dictionary is held by reference, so later mutations made by
    whoever owns it are reflected in subsequent responses.
    """

    def __init__(self, status_dict):
        """Remember *status_dict* and initialise the base Resource."""
        self.status_dict = status_dict
        Resource.__init__(self)

    def render_GET(self, req):
        """Return the current status dictionary serialised as JSON bytes."""
        req.setHeader(b"Content-Type", "application/json")
        body = json.dumps(self.status_dict)
        return body.encode()
class Transcriber():
    """Runs alignment / transcription jobs and tracks their status.

    Each job is identified by a ``uid`` and keeps its files under
    ``<data_dir>/transcriptions/<uid>``. Progress is published through a
    per-uid status dictionary that callers obtain with ``get_status``.
    """

    def __init__(self, data_dir, nthreads=4, ntranscriptionthreads=2):
        """Create the shared gentle resources and the full transcriber.

        data_dir -- root directory under which job output is stored.
        nthreads -- thread count passed to each ``gentle.ForcedAligner``.
        ntranscriptionthreads -- thread count passed to
            ``gentle.FullTranscriber``.
        """
        self.data_dir = data_dir
        self.nthreads = nthreads
        self.ntranscriptionthreads = ntranscriptionthreads
        self.resources = gentle.Resources()
        self.full_transcriber = gentle.FullTranscriber(self.resources, nthreads=ntranscriptionthreads)
        self._status_dicts = {}

    def get_status(self, uid):
        """Return the mutable status dict for *uid*, creating it if needed."""
        return self._status_dicts.setdefault(uid, {})

    def out_dir(self, uid):
        """Return the directory that holds the files of job *uid*."""
        return os.path.join(self.data_dir, 'transcriptions', uid)

    # TODO(maxhawkins): refactor so this is returned by transcribe()
    def next_id(self):
        """Return a fresh 8-character hex id with no existing output directory."""
        while True:
            uid = uuid.uuid4().hex[:8]
            # Check the directory the job will really use
            # (<data_dir>/transcriptions/<uid>), not <data_dir>/<uid>.
            if not os.path.exists(self.out_dir(uid)):
                return uid

    def transcribe(self, uid, transcript, audio, async_mode, **kwargs):
        """Align *transcript* to *audio* (or fully transcribe) for job *uid*.

        The output directory for *uid* must already exist. Progress and
        errors are reported through the status dict of *uid*.

        uid -- job identifier.
        transcript -- transcript text; if blank, full transcription is
            attempted when ``self.full_transcriber.available`` is true.
        audio -- raw bytes of the uploaded media file.
        async_mode -- not used here; kept for interface compatibility.
        kwargs -- forwarded to ``gentle.ForcedAligner``.

        Returns the transcription object on success, or ``None`` when the
        job ends in a reported error (encoding failure, or no transcript
        and no language model). Exceptions raised while transcribing are
        recorded in the status dict and re-raised.
        """
        status = self.get_status(uid)
        status['status'] = 'STARTED'

        outdir = self.out_dir(uid)
        upload_path = os.path.join(outdir, 'upload')
        wavfile = os.path.join(outdir, 'a.wav')

        with open(os.path.join(outdir, 'transcript.txt'), 'w') as tranfile:
            tranfile.write(transcript)
        with open(upload_path, 'wb') as uploadfile:
            uploadfile.write(audio)

        status['status'] = 'ENCODING'
        if gentle.resample(upload_path, wavfile) != 0:
            status['status'] = 'ERROR'
            status['error'] = "Encoding failed. Make sure that you've uploaded a valid media file."
            # Save the status so that errors are recovered on restart of the server
            # XXX: This won't work, because the endpoint will override this file
            with open(os.path.join(outdir, 'status.json'), 'w') as jsfile:
                json.dump(status, jsfile, indent=2)
            return

        #XXX: Maybe we should pass this wave object instead of the
        # file path to align_progress
        # The context manager closes the wave file once the duration is read.
        with wave.open(wavfile, 'rb') as wav_obj:
            status['duration'] = wav_obj.getnframes() / float(wav_obj.getframerate())
        status['status'] = 'TRANSCRIBING'

        def on_progress(p):
            # Mirror every progress field into the shared status dict.
            print(p)
            for k, v in p.items():
                status[k] = v

        try:
            if transcript.strip():
                trans = gentle.ForcedAligner(self.resources, transcript, nthreads=self.nthreads, **kwargs)
            elif self.full_transcriber.available:
                trans = self.full_transcriber
            else:
                status['status'] = 'ERROR'
                status['error'] = 'No transcript provided and no language model for full transcription'
                return
            output = trans.transcribe(wavfile, progress_cb=on_progress, logging=logging)
        except Exception:
            # Without this the status would stay 'TRANSCRIBING' forever and
            # polling clients could never learn that the job failed.
            logging.exception('transcription %s failed', uid)
            status['status'] = 'ERROR'
            status['error'] = 'Transcription failed.'
            raise

        # ...remove the original upload
        os.unlink(upload_path)

        # Save
        with open(os.path.join(outdir, 'align.json'), 'w') as jsfile:
            jsfile.write(output.to_json(indent=2))
        with open(os.path.join(outdir, 'align.csv'), 'w') as csvfile:
            csvfile.write(output.to_csv())

        # Inline the alignment into the index.html file.
        with open(get_resource('www/view_alignment.html')) as htmlfile:
            htmltxt = htmlfile.read()
        # The JSON is pasted inside a <script> element and may contain
        # user-supplied text; escaping "</" keeps a literal "</script>" from
        # terminating the element. "<\/" is equivalent inside JSON strings.
        inline_json = output.to_json().replace('</', '<\\/')
        htmltxt = htmltxt.replace("var INLINE_JSON;", "var INLINE_JSON=%s;" % (inline_json,))
        with open(os.path.join(outdir, 'index.html'), 'w') as htmlfile:
            htmlfile.write(htmltxt)

        status['status'] = 'OK'
        logging.info('done with transcription.')
        return output
class TranscriptionsController(Resource):
    """Resource that accepts transcription uploads and serves job output.

    POST starts a new job. Child paths are treated as job ids and are
    served from the job's output directory, with a live ``status.json``
    child added on top.
    """

    def __init__(self, transcriber):
        """Keep a reference to the Transcriber that runs the jobs."""
        Resource.__init__(self)
        self.transcriber = transcriber

    def getChild(self, uid, req):
        """Return a File resource for job *uid* with a live status endpoint."""
        name = uid.decode()
        # Path segments arrive percent-decoded, so a segment may contain
        # separators ("%2Fetc") or be a dot segment. Joining that onto the
        # data directory would escape it; answer as for an unknown child.
        if name in ('.', '..') or os.path.basename(name) != name:
            return Resource.getChild(self, uid, req)

        trans_ctrl = File(self.transcriber.out_dir(name))

        # Add a Status endpoint to the file
        trans_status = TranscriptionStatus(self.transcriber.get_status(name))
        trans_ctrl.putChild(b"status.json", trans_status)

        return trans_ctrl

    def render_POST(self, req):
        """Start a transcription job from the posted form fields.

        Reads ``audio`` (required), ``transcript``, ``disfluency``,
        ``conservative`` and ``async`` from ``req.args``. With
        ``async=false`` the response is held open and the resulting JSON
        is written when the job finishes; otherwise the client is
        redirected to the job page immediately.
        """
        audio_parts = req.args.get(b'audio')
        if not audio_parts:
            # A missing upload is a client error, not a server crash.
            req.setResponseCode(400)
            req.setHeader(b"Content-Type", b"application/json")
            return json.dumps({'status': 'ERROR',
                               'error': 'No audio file was uploaded.'}).encode()

        uid = self.transcriber.next_id()

        tran = req.args.get(b'transcript', [b''])[0].decode()
        audio = audio_parts[0]

        kwargs = {'disfluency': b'disfluency' in req.args,
                  'conservative': b'conservative' in req.args,
                  'disfluencies': {'uh', 'um'}}

        # Asynchronous unless the client explicitly sends async=false.
        async_mode = req.args.get(b'async', [b''])[0] != b'false'

        # We need to make the transcription directory here, so that
        # when we redirect the user we are sure that there's a place
        # for them to go.
        outdir = self.transcriber.out_dir(uid)
        os.makedirs(outdir)

        # Copy over the HTML
        shutil.copy(get_resource('www/view_alignment.html'), os.path.join(outdir, 'index.html'))

        result_promise = threads.deferToThreadPool(
            reactor, reactor.getThreadPool(),
            self.transcriber.transcribe,
            uid, tran, audio, async_mode, **kwargs)

        if not async_mode:
            # Set once the connection is lost; the request must not be
            # written to or finished after that.
            client_gone = []

            def write_result(result):
                '''Write JSON to client on completion'''
                req.setHeader(b"Content-Type", b"application/json")
                if result is None:
                    # transcribe() returns None when it recorded an error
                    # in the status dict; report that instead of hanging.
                    req.setResponseCode(500)
                    body = json.dumps(self.transcriber.get_status(uid))
                else:
                    body = result.to_json(indent=2)
                req.write(body.encode())
                req.finish()

            def write_failure(failure):
                '''Finish the request with an error instead of leaving it open'''
                if client_gone or req.finished:
                    return
                logging.error('transcription %s failed: %s', uid, failure.getErrorMessage())
                req.setResponseCode(500)
                req.setHeader(b"Content-Type", b"application/json")
                req.write(json.dumps({'status': 'ERROR',
                                      'error': 'Transcription failed.'}).encode())
                req.finish()

            def on_disconnect(_):
                client_gone.append(True)
                result_promise.cancel()

            result_promise.addCallback(write_result)
            result_promise.addErrback(write_failure)

            req.notifyFinish().addErrback(on_disconnect)
            return NOT_DONE_YET

        req.setResponseCode(FOUND)
        req.setHeader(b"Location", "/transcriptions/%s" % (uid))
        return b''
class LazyZipper(Insist):
    """Insist resource backed by ``<cachedir>/<uid>.zip``.

    NOTE(review): presumably Insist calls ``serialize_computation`` on
    demand to produce the file -- confirm in gentle.util.cyst.
    """

    def __init__(self, cachedir, transcriber, uid):
        """Store the job reference and register the zip path with Insist."""
        self.transcriber = transcriber
        self.uid = uid
        zip_path = os.path.join(cachedir, '%s.zip' % (uid))
        Insist.__init__(self, zip_path)

    def serialize_computation(self, outpath):
        """Zip the job's output directory into the archive at *outpath*."""
        # `shutil.make_archive` appends ".zip" itself, so hand it the path
        # with everything from the last "." onwards removed.
        archive_base = outpath.rpartition('.')[0]
        source_dir = os.path.join(self.transcriber.out_dir(self.uid))
        shutil.make_archive(archive_base, "zip", source_dir)
class TranscriptionZipper(Resource):
    """Resource that hands out zip downloads of job output directories.

    A child path such as ``<uid>.zip`` is mapped to a LazyZipper for that
    job, which is then registered as a static child so it is reused.
    """

    def __init__(self, cachedir, transcriber):
        """Remember where zips are cached and which Transcriber owns the jobs."""
        self.cachedir = cachedir
        self.transcriber = transcriber
        Resource.__init__(self)

    def getChild(self, path, req):
        """Return a LazyZipper for the job named by *path*, or the default child."""
        uid = path.decode().split('.')[0]

        # Path segments arrive percent-decoded, so the uid may contain
        # separators ("%2Fetc"); joined onto the data and cache directories
        # that would read and write outside them. An empty uid would zip the
        # whole transcriptions tree. Treat both as unknown children.
        if not uid or os.path.basename(uid) != uid:
            return Resource.getChild(self, path, req)

        t_dir = self.transcriber.out_dir(uid)
        if not os.path.exists(t_dir):
            return Resource.getChild(self, path, req)

        # TODO: Check that "status" is complete and only create a LazyZipper if so
        # Otherwise, we could have incomplete transcriptions that get permanently zipped.
        # For now, a solution will be hiding the button in the client until it's done.
        lz = LazyZipper(self.cachedir, self.transcriber, uid)
        self.putChild(path, lz)
        return lz
def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, ntranscriptionthreads=2, data_dir=get_datadir('webdata')):
    """Build the resource tree and run the HTTP server until the reactor stops.

    port, interface -- address handed to ``reactor.listenTCP``.
    installSignalHandlers -- forwarded to ``reactor.run``.
    nthreads, ntranscriptionthreads -- forwarded to ``Transcriber``.
    data_dir -- served as the site root; also holds the ``zip`` cache
        directory. Both directories are created if missing.
    """
    logging.info("SERVE %d, %s, %d", port, interface, installSignalHandlers)

    zip_dir = os.path.join(data_dir, 'zip')
    for directory in (data_dir, zip_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Static pages that sit on top of the data directory.
    root = File(data_dir)
    static_children = (
        (b'', 'www/index.html'),
        (b'status.html', 'www/status.html'),
        (b'preloader.gif', 'www/preloader.gif'),
    )
    for child_name, resource_name in static_children:
        root.putChild(child_name, File(get_resource(resource_name)))

    transcriber = Transcriber(data_dir, nthreads=nthreads, ntranscriptionthreads=ntranscriptionthreads)
    root.putChild(b'transcriptions', TranscriptionsController(transcriber))
    root.putChild(b'zip', TranscriptionZipper(zip_dir, transcriber))

    site = Site(root)
    logging.info("about to listen")
    reactor.listenTCP(port, site, interface=interface)
    logging.info("listening")

    reactor.run(installSignalHandlers=installSignalHandlers)
if __name__=='__main__':
    # Command-line entry point: parse the options, set up logging, serve.
    import argparse

    cli = argparse.ArgumentParser(
        description='Align a transcript to audio by generating a new language model.')
    cli.add_argument(
        '--host', default="0.0.0.0",
        help='host to run http server on')
    cli.add_argument(
        '--port', default=8765, type=int,
        help='port number to run http server on')
    cli.add_argument(
        '--nthreads', default=multiprocessing.cpu_count(), type=int,
        help='number of alignment threads')
    cli.add_argument(
        '--ntranscriptionthreads', default=2, type=int,
        help='number of full-transcription threads (memory intensive)')
    cli.add_argument(
        '--log', default="INFO",
        help='the log level (DEBUG, INFO, WARNING, ERROR, or CRITICAL)')

    opts = cli.parse_args()

    # The level name is upper-cased so "--log debug" works as well.
    logging.getLogger().setLevel(opts.log.upper())
    logging.info('gentle %s', gentle.__version__)
    logging.info('listening at %s:%d\n', opts.host, opts.port)

    serve(opts.port, opts.host,
          nthreads=opts.nthreads,
          ntranscriptionthreads=opts.ntranscriptionthreads,
          installSignalHandlers=1)