alaselababatunde committed on
Commit
351bb59
·
1 Parent(s): 22ecd09
Files changed (1) hide show
  1. main.py +7 -15
main.py CHANGED
@@ -7,8 +7,8 @@ from spitch import Spitch
7
  from langchain.prompts import PromptTemplate
8
  from langchain_huggingface import HuggingFaceEndpoint
9
  from langdetect import detect, DetectorFactory
10
- from huggingface_hub.utils import HfHubHTTPError # for quota error handling
11
- from smebuilder_vector import retriever # <-- your retriever
12
 
13
  # ----------------- CONFIG -----------------
14
  DetectorFactory.seed = 0
@@ -25,14 +25,14 @@ if not SPITCH_API_KEY:
25
  os.environ["SPITCH_API_KEY"] = SPITCH_API_KEY
26
  spitch_client = Spitch()
27
 
28
- # HuggingFace LLM (better tuned for code generation)
29
# HuggingFace text-generation endpoint.
# Sampling is enabled with moderate temperature / nucleus cutoff and a light
# repetition penalty; max_new_tokens is sized for long generated replies.
llm = HuggingFaceEndpoint(
    repo_id=HF_MODEL,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    repetition_penalty=1.1,
    max_new_tokens=2048,
)
37
 
38
  # FastAPI app
@@ -143,8 +143,7 @@ async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None,
143
 
144
  suffix = os.path.splitext(file.filename)[1] or ".wav"
145
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
146
- content = await file.read()
147
- tf.write(content)
148
  tmp_path = tf.name
149
 
150
  try:
@@ -156,7 +155,6 @@ async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None,
156
  resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
157
 
158
  transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
159
-
160
  try:
161
  detected_lang = detect(transcription) if transcription.strip() else "en"
162
  except Exception:
@@ -171,7 +169,6 @@ async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None,
171
  translation = transcription
172
 
173
  reply = stt_chain.invoke({"speech": translation})
174
-
175
  return {
176
  "transcription": transcription,
177
  "detected_language": detected_lang,
@@ -189,7 +186,6 @@ def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
189
  async def sme_generate(payload: dict = Body(...)):
190
  try:
191
  user_prompt = payload.get("user_prompt", "")
192
- # retrieve context
193
  context_docs = retriever.get_relevant_documents(user_prompt)
194
  context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
195
  response = sme_chain.invoke({"user_prompt": user_prompt, "context": context})
@@ -205,8 +201,7 @@ async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | Non
205
 
206
  suffix = os.path.splitext(file.filename)[1] or ".wav"
207
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
208
- content = await file.read()
209
- tf.write(content)
210
  tmp_path = tf.name
211
 
212
  try:
@@ -218,7 +213,6 @@ async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | Non
218
  resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
219
 
220
  transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
221
-
222
  try:
223
  detected_lang = detect(transcription) if transcription.strip() else "en"
224
  except Exception:
@@ -233,10 +227,8 @@ async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | Non
233
  translation = transcription
234
 
235
  try:
236
- # vector retrieval here too
237
  context_docs = retriever.get_relevant_documents(translation)
238
  context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
239
-
240
  sme_response = sme_chain.invoke({"user_prompt": translation, "context": context})
241
  return {
242
  "success": True,
@@ -250,7 +242,7 @@ async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | Non
250
  return {"success": False, "error": "⚠️ Token quota for today has been used. Please come back in 24 hours."}
251
  raise e
252
 
253
- # Hugging Face requires port 7860
254
# Script entry point. NOTE(review): Hugging Face Spaces requires listening on
# port 7860 — presumably why it is hard-coded here; confirm against deployment.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
 
7
  from langchain.prompts import PromptTemplate
8
  from langchain_huggingface import HuggingFaceEndpoint
9
  from langdetect import detect, DetectorFactory
10
+ from huggingface_hub.utils import HfHubHTTPError
11
+ from smebuilder_vector import retriever # your retriever
12
 
13
  # ----------------- CONFIG -----------------
14
  DetectorFactory.seed = 0
 
25
  os.environ["SPITCH_API_KEY"] = SPITCH_API_KEY
26
  spitch_client = Spitch()
27
 
28
# HuggingFace LLM endpoint.
# Stochastic decoding (do_sample) with temperature 0.7 and top-p 0.9, a mild
# repetition penalty, and up to 2048 newly generated tokens per call.
llm = HuggingFaceEndpoint(
    repo_id=HF_MODEL,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    repetition_penalty=1.1,
    max_new_tokens=2048,
)
37
 
38
  # FastAPI app
 
143
 
144
  suffix = os.path.splitext(file.filename)[1] or ".wav"
145
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
146
+ tf.write(await file.read())
 
147
  tmp_path = tf.name
148
 
149
  try:
 
155
  resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
156
 
157
  transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
 
158
  try:
159
  detected_lang = detect(transcription) if transcription.strip() else "en"
160
  except Exception:
 
169
  translation = transcription
170
 
171
  reply = stt_chain.invoke({"speech": translation})
 
172
  return {
173
  "transcription": transcription,
174
  "detected_language": detected_lang,
 
186
  async def sme_generate(payload: dict = Body(...)):
187
  try:
188
  user_prompt = payload.get("user_prompt", "")
 
189
  context_docs = retriever.get_relevant_documents(user_prompt)
190
  context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
191
  response = sme_chain.invoke({"user_prompt": user_prompt, "context": context})
 
201
 
202
  suffix = os.path.splitext(file.filename)[1] or ".wav"
203
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
204
+ tf.write(await file.read())
 
205
  tmp_path = tf.name
206
 
207
  try:
 
213
  resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
214
 
215
  transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
 
216
  try:
217
  detected_lang = detect(transcription) if transcription.strip() else "en"
218
  except Exception:
 
227
  translation = transcription
228
 
229
  try:
 
230
  context_docs = retriever.get_relevant_documents(translation)
231
  context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
 
232
  sme_response = sme_chain.invoke({"user_prompt": translation, "context": context})
233
  return {
234
  "success": True,
 
242
  return {"success": False, "error": "⚠️ Token quota for today has been used. Please come back in 24 hours."}
243
  raise e
244
 
245
# ----------------- MAIN -----------------
# Launch the ASGI server when run as a script; the module-level `app` is used
# when imported by an external server instead.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)