Update app.py
Browse files
app.py
CHANGED
|
@@ -1,10 +1,6 @@
|
|
| 1 |
from flask import Flask, render_template, request, jsonify
|
| 2 |
from qdrant_client import QdrantClient
|
| 3 |
from qdrant_client import models
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
import torch
|
| 6 |
-
from torch import Tensor
|
| 7 |
-
from transformers import AutoTokenizer, AutoModel
|
| 8 |
from qdrant_client.models import Batch, PointStruct
|
| 9 |
from pickle import load, dump
|
| 10 |
import numpy as np
|
|
@@ -12,7 +8,6 @@ import os, time, sys
|
|
| 12 |
from datetime import datetime as dt
|
| 13 |
from datetime import timedelta
|
| 14 |
from datetime import timezone
|
| 15 |
-
from faster_whisper import WhisperModel
|
| 16 |
import io
|
| 17 |
import requests
|
| 18 |
|
|
@@ -20,8 +15,8 @@ app = Flask(__name__)
|
|
| 20 |
|
| 21 |
# Faster Whisper setup
|
| 22 |
# model_size = 'small'
|
| 23 |
-
beamsize = 2
|
| 24 |
-
wmodel = WhisperModel("guillaumekln/faster-whisper-small", device="cpu", compute_type="int8")
|
| 25 |
|
| 26 |
# Initialize Qdrant Client and other required settings
|
| 27 |
qdrant_api_key = os.environ.get("qdrant_api_key")
|
|
@@ -29,15 +24,7 @@ qdrant_url = os.environ.get("qdrant_url")
|
|
| 29 |
|
| 30 |
client = QdrantClient(url=qdrant_url, port=443, api_key=qdrant_api_key, prefer_grpc=False)
|
| 31 |
|
| 32 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 33 |
-
|
| 34 |
-
def average_pool(last_hidden_states: Tensor,
|
| 35 |
-
attention_mask: Tensor) -> Tensor:
|
| 36 |
-
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
| 37 |
-
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
| 38 |
-
|
| 39 |
-
tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base-v2')
|
| 40 |
-
model = AutoModel.from_pretrained('intfloat/e5-base-v2').to(device)
|
| 41 |
|
| 42 |
def e5embed(query):
|
| 43 |
batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
|
|
@@ -134,26 +121,6 @@ def delete_joke():
|
|
| 134 |
client.delete(collection_name=collection_name, points_selector=models.PointIdsList(points=[int(joke_id)],),)
|
| 135 |
return jsonify({"deleted": True})
|
| 136 |
|
| 137 |
-
@app.route("/whisper_transcribe", methods=["POST"])
|
| 138 |
-
def whisper_transcribe():
|
| 139 |
-
if 'audio' not in request.files: return jsonify({'error': 'No file provided'}), 400
|
| 140 |
-
|
| 141 |
-
audio_file = request.files['audio']
|
| 142 |
-
allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
|
| 143 |
-
if not (audio_file and audio_file.filename.lower().split('.')[-1] in allowed_extensions): return jsonify({'error': 'Invalid file format'}), 400
|
| 144 |
-
|
| 145 |
-
print('Transcribing audio')
|
| 146 |
-
audio_bytes = audio_file.read()
|
| 147 |
-
audio_file = io.BytesIO(audio_bytes)
|
| 148 |
-
|
| 149 |
-
segments, info = wmodel.transcribe(audio_file, beam_size=beamsize) # beamsize is 2.
|
| 150 |
-
text = ''
|
| 151 |
-
starttime = time.time()
|
| 152 |
-
for segment in segments:
|
| 153 |
-
text += segment.text
|
| 154 |
-
print('Time to transcribe:', time.time() - starttime, 'seconds')
|
| 155 |
-
|
| 156 |
-
return jsonify({'transcription': text})
|
| 157 |
|
| 158 |
|
| 159 |
if __name__ == "__main__":
|
|
|
|
| 1 |
from flask import Flask, render_template, request, jsonify
|
| 2 |
from qdrant_client import QdrantClient
|
| 3 |
from qdrant_client import models
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from qdrant_client.models import Batch, PointStruct
|
| 5 |
from pickle import load, dump
|
| 6 |
import numpy as np
|
|
|
|
| 8 |
from datetime import datetime as dt
|
| 9 |
from datetime import timedelta
|
| 10 |
from datetime import timezone
|
|
|
|
| 11 |
import io
|
| 12 |
import requests
|
| 13 |
|
|
|
|
| 15 |
|
| 16 |
# Faster Whisper setup
|
| 17 |
# model_size = 'small'
|
| 18 |
+
# beamsize = 2
|
| 19 |
+
# wmodel = WhisperModel("guillaumekln/faster-whisper-small", device="cpu", compute_type="int8")
|
| 20 |
|
| 21 |
# Initialize Qdrant Client and other required settings
|
| 22 |
qdrant_api_key = os.environ.get("qdrant_api_key")
|
|
|
|
| 24 |
|
| 25 |
client = QdrantClient(url=qdrant_url, port=443, api_key=qdrant_api_key, prefer_grpc=False)
|
| 26 |
|
| 27 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
def e5embed(query):
|
| 30 |
batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
|
|
|
|
| 121 |
client.delete(collection_name=collection_name, points_selector=models.PointIdsList(points=[int(joke_id)],),)
|
| 122 |
return jsonify({"deleted": True})
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
|
| 126 |
if __name__ == "__main__":
|