Reggie commited on
Commit
5f7526f
·
verified ·
1 Parent(s): 47cf8b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -36
app.py CHANGED
@@ -1,10 +1,6 @@
1
  from flask import Flask, render_template, request, jsonify
2
  from qdrant_client import QdrantClient
3
  from qdrant_client import models
4
- import torch.nn.functional as F
5
- import torch
6
- from torch import Tensor
7
- from transformers import AutoTokenizer, AutoModel
8
  from qdrant_client.models import Batch, PointStruct
9
  from pickle import load, dump
10
  import numpy as np
@@ -12,7 +8,6 @@ import os, time, sys
12
  from datetime import datetime as dt
13
  from datetime import timedelta
14
  from datetime import timezone
15
- from faster_whisper import WhisperModel
16
  import io
17
  import requests
18
 
@@ -20,8 +15,8 @@ app = Flask(__name__)
20
 
21
  # Faster Whisper setup
22
  # model_size = 'small'
23
- beamsize = 2
24
- wmodel = WhisperModel("guillaumekln/faster-whisper-small", device="cpu", compute_type="int8")
25
 
26
  # Initialize Qdrant Client and other required settings
27
  qdrant_api_key = os.environ.get("qdrant_api_key")
@@ -29,15 +24,7 @@ qdrant_url = os.environ.get("qdrant_url")
29
 
30
  client = QdrantClient(url=qdrant_url, port=443, api_key=qdrant_api_key, prefer_grpc=False)
31
 
32
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
-
34
- def average_pool(last_hidden_states: Tensor,
35
- attention_mask: Tensor) -> Tensor:
36
- last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
37
- return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
38
-
39
- tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base-v2')
40
- model = AutoModel.from_pretrained('intfloat/e5-base-v2').to(device)
41
 
42
  def e5embed(query):
43
  batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
@@ -134,26 +121,6 @@ def delete_joke():
134
  client.delete(collection_name=collection_name, points_selector=models.PointIdsList(points=[int(joke_id)],),)
135
  return jsonify({"deleted": True})
136
 
137
- @app.route("/whisper_transcribe", methods=["POST"])
138
- def whisper_transcribe():
139
- if 'audio' not in request.files: return jsonify({'error': 'No file provided'}), 400
140
-
141
- audio_file = request.files['audio']
142
- allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
143
- if not (audio_file and audio_file.filename.lower().split('.')[-1] in allowed_extensions): return jsonify({'error': 'Invalid file format'}), 400
144
-
145
- print('Transcribing audio')
146
- audio_bytes = audio_file.read()
147
- audio_file = io.BytesIO(audio_bytes)
148
-
149
- segments, info = wmodel.transcribe(audio_file, beam_size=beamsize) # beamsize is 2.
150
- text = ''
151
- starttime = time.time()
152
- for segment in segments:
153
- text += segment.text
154
- print('Time to transcribe:', time.time() - starttime, 'seconds')
155
-
156
- return jsonify({'transcription': text})
157
 
158
 
159
  if __name__ == "__main__":
 
1
  from flask import Flask, render_template, request, jsonify
2
  from qdrant_client import QdrantClient
3
  from qdrant_client import models
 
 
 
 
4
  from qdrant_client.models import Batch, PointStruct
5
  from pickle import load, dump
6
  import numpy as np
 
8
  from datetime import datetime as dt
9
  from datetime import timedelta
10
  from datetime import timezone
 
11
  import io
12
  import requests
13
 
 
15
 
16
  # Faster Whisper setup
17
  # model_size = 'small'
18
+ # beamsize = 2
19
+ # wmodel = WhisperModel("guillaumekln/faster-whisper-small", device="cpu", compute_type="int8")
20
 
21
  # Initialize Qdrant Client and other required settings
22
  qdrant_api_key = os.environ.get("qdrant_api_key")
 
24
 
25
  client = QdrantClient(url=qdrant_url, port=443, api_key=qdrant_api_key, prefer_grpc=False)
26
 
27
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
28
 
29
  def e5embed(query):
30
  batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
 
121
  client.delete(collection_name=collection_name, points_selector=models.PointIdsList(points=[int(joke_id)],),)
122
  return jsonify({"deleted": True})
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
 
126
  if __name__ == "__main__":