lydiasolomon commited on
Commit
3c9086c
·
verified ·
1 Parent(s): 9ec1122

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +57 -67
main.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import tempfile
 
3
  from fastapi import FastAPI, UploadFile, File, Header, HTTPException, Body
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from pydantic import BaseModel
@@ -21,11 +22,10 @@ PROJECT_API_KEY = os.getenv("PROJECT_API_KEY")
21
  if not SPITCH_API_KEY:
22
  raise RuntimeError("Set SPITCH_API_KEY in environment before starting.")
23
 
24
- # Init Spitch
25
  os.environ["SPITCH_API_KEY"] = SPITCH_API_KEY
26
  spitch_client = Spitch()
27
 
28
- # HuggingFace LLM (removed task="conversational" to avoid StopIteration bug)
29
  llm = HuggingFaceEndpoint(
30
  repo_id=HF_MODEL,
31
  temperature=0.7,
@@ -35,10 +35,9 @@ llm = HuggingFaceEndpoint(
35
  max_new_tokens=2048
36
  )
37
 
38
- # FastAPI app
39
  app = FastAPI(title="DevAssist AI Backend (FastAPI + LangChain)")
40
 
41
- # CORS
42
  app.add_middleware(
43
  CORSMiddleware,
44
  allow_origins=[FRONTEND_ORIGIN] if FRONTEND_ORIGIN != "*" else ["*"],
@@ -111,8 +110,7 @@ class AutoDocRequest(BaseModel):
111
 
112
  # ----------------- AUTH -----------------
113
  def check_auth(authorization: str | None):
114
- """Validate Bearer token against PROJECT_API_KEY"""
115
- if not PROJECT_API_KEY: # If not set, skip auth
116
  return
117
  if not authorization or not authorization.startswith("Bearer "):
118
  raise HTTPException(status_code=401, detail="Missing bearer token")
@@ -120,38 +118,41 @@ def check_auth(authorization: str | None):
120
  if token != PROJECT_API_KEY:
121
  raise HTTPException(status_code=403, detail="Invalid token")
122
 
123
- # ----------------- ENDPOINTS -----------------
124
- @app.get("/")
125
- def root():
126
- return {"status": "✅ DevAssist AI Backend running"}
127
-
128
- @app.post("/chat")
129
- def chat(req: ChatRequest, authorization: str | None = Header(None)):
130
- check_auth(authorization)
131
  try:
132
- answer = chat_chain.invoke({"question": req.question})
133
- return {"reply": answer.strip() if isinstance(answer, str) else str(answer)}
134
- except HfHubHTTPError as e:
135
- if "exceeded" in str(e).lower() or "quota" in str(e).lower():
136
- return {"reply": "⚠️ Daily token limit reached. Try again in 24 hours."}
137
- raise e
138
 
139
- @app.post("/stt")
140
- async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
141
- check_auth(authorization)
 
 
142
 
 
143
  suffix = os.path.splitext(file.filename)[1] or ".wav"
144
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
145
  tf.write(await file.read())
146
  tmp_path = tf.name
147
 
 
 
 
148
  try:
149
  if lang_hint:
150
- resp = spitch_client.speech.transcribe(language=lang_hint, content=open(tmp_path, "rb").read())
151
  else:
152
- resp = spitch_client.speech.transcribe(content=open(tmp_path, "rb").read())
153
  except Exception:
154
- resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
155
 
156
  transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
157
  detected_lang = "en"
@@ -168,19 +169,36 @@ async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None,
168
  except Exception:
169
  translation = transcription
170
 
171
- reply = stt_chain.invoke({"speech": translation})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  return {
173
  "transcription": transcription,
174
  "detected_language": detected_lang,
175
  "translation": translation,
176
- "reply": reply.strip() if isinstance(reply, str) else str(reply)
177
  }
178
 
179
  @app.post("/autodoc")
180
  def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
181
  check_auth(authorization)
182
- docs = autodoc_chain.invoke({"code": req.code})
183
- return {"documentation": docs.strip() if isinstance(docs, str) else str(docs)}
184
 
185
  @app.post("/sme/generate")
186
  async def sme_generate(payload: dict = Body(...), authorization: str | None = Header(None)):
@@ -189,56 +207,28 @@ async def sme_generate(payload: dict = Body(...), authorization: str | None = He
189
  user_prompt = payload.get("user_prompt", "")
190
  context_docs = retriever.get_relevant_documents(user_prompt)
191
  context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
192
- response = sme_chain.invoke({"user_prompt": user_prompt, "context": context})
193
- return {"success": True, "data": response}
194
- except Exception as e:
195
- return {"success": False, "error": f"⚠️ LLM error: {str(e)}"}
196
 
197
  @app.post("/sme/speech-generate")
198
  async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
199
  check_auth(authorization)
200
-
201
- suffix = os.path.splitext(file.filename)[1] or ".wav"
202
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
203
- tf.write(await file.read())
204
- tmp_path = tf.name
205
-
206
- try:
207
- if lang_hint:
208
- resp = spitch_client.speech.transcribe(language=lang_hint, content=open(tmp_path, "rb").read())
209
- else:
210
- resp = spitch_client.speech.transcribe(content=open(tmp_path, "rb").read())
211
- except Exception:
212
- resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
213
-
214
- transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
215
- detected_lang = "en"
216
- try:
217
- detected_lang = detect(transcription) if transcription.strip() else "en"
218
- except Exception:
219
- pass
220
-
221
- translation = transcription
222
- if detected_lang != "en":
223
- try:
224
- translation_resp = spitch_client.text.translate(text=transcription, source=detected_lang, target="en")
225
- translation = getattr(translation_resp, "text", "") or translation_resp.get("text", "")
226
- except Exception:
227
- translation = transcription
228
-
229
  try:
230
  context_docs = retriever.get_relevant_documents(translation)
231
  context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
232
- sme_response = sme_chain.invoke({"user_prompt": translation, "context": context})
233
  return {
234
  "success": True,
235
  "transcription": transcription,
236
  "detected_language": detected_lang,
237
  "translation": translation,
238
- "sme_site": sme_response
239
  }
240
- except Exception as e:
241
- return {"success": False, "error": f"⚠️ LLM error: {str(e)}"}
242
 
243
  # ----------------- MAIN -----------------
244
  if __name__ == "__main__":
 
1
  import os
2
  import tempfile
3
+ import traceback
4
  from fastapi import FastAPI, UploadFile, File, Header, HTTPException, Body
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from pydantic import BaseModel
 
22
  if not SPITCH_API_KEY:
23
  raise RuntimeError("Set SPITCH_API_KEY in environment before starting.")
24
 
 
25
  os.environ["SPITCH_API_KEY"] = SPITCH_API_KEY
26
  spitch_client = Spitch()
27
 
28
+ # HuggingFace LLM
29
  llm = HuggingFaceEndpoint(
30
  repo_id=HF_MODEL,
31
  temperature=0.7,
 
35
  max_new_tokens=2048
36
  )
37
 
38
+ # ----------------- FASTAPI -----------------
39
  app = FastAPI(title="DevAssist AI Backend (FastAPI + LangChain)")
40
 
 
41
  app.add_middleware(
42
  CORSMiddleware,
43
  allow_origins=[FRONTEND_ORIGIN] if FRONTEND_ORIGIN != "*" else ["*"],
 
110
 
111
  # ----------------- AUTH -----------------
112
  def check_auth(authorization: str | None):
113
+ if not PROJECT_API_KEY:
 
114
  return
115
  if not authorization or not authorization.startswith("Bearer "):
116
  raise HTTPException(status_code=401, detail="Missing bearer token")
 
118
  if token != PROJECT_API_KEY:
119
  raise HTTPException(status_code=403, detail="Invalid token")
120
 
121
+ # ----------------- HELPER FUNCTIONS -----------------
122
+ def run_chain(chain, input_dict: dict):
123
+ """
124
+ Safely run a LangChain PromptTemplate | HuggingFaceEndpoint chain.
125
+ Converts output to string and captures errors.
126
+ """
 
 
127
  try:
128
+ # Render template
129
+ if hasattr(chain, "prompt"):
130
+ prompt_text = chain.prompt.format(**input_dict)
131
+ else:
132
+ prompt_text = str(input_dict)
 
133
 
134
+ # Generate using HuggingFaceEndpoint (expects str input)
135
+ output = chain.llm.generate([{"role": "user", "content": prompt_text}])
136
+ return output.generations[0][0].text.strip()
137
+ except Exception:
138
+ return {"success": False, "error": "⚠️ LLM error", "details": traceback.format_exc()}
139
 
140
+ async def process_audio(file: UploadFile, lang_hint: str | None = None):
141
  suffix = os.path.splitext(file.filename)[1] or ".wav"
142
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
143
  tf.write(await file.read())
144
  tmp_path = tf.name
145
 
146
+ with open(tmp_path, "rb") as f:
147
+ audio_bytes = f.read()
148
+
149
  try:
150
  if lang_hint:
151
+ resp = spitch_client.speech.transcribe(language=lang_hint, content=audio_bytes)
152
  else:
153
+ resp = spitch_client.speech.transcribe(content=audio_bytes)
154
  except Exception:
155
+ resp = spitch_client.speech.transcribe(language="en", content=audio_bytes)
156
 
157
  transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
158
  detected_lang = "en"
 
169
  except Exception:
170
  translation = transcription
171
 
172
+ return transcription, detected_lang, translation
173
+
174
+ # ----------------- ENDPOINTS -----------------
175
+ @app.get("/")
176
+ def root():
177
+ return {"status": "✅ DevAssist AI Backend running"}
178
+
179
+ @app.post("/chat")
180
+ def chat(req: ChatRequest, authorization: str | None = Header(None)):
181
+ check_auth(authorization)
182
+ result = run_chain(chat_chain, {"question": req.question})
183
+ return result if isinstance(result, dict) else {"reply": result}
184
+
185
+ @app.post("/stt")
186
+ async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
187
+ check_auth(authorization)
188
+ transcription, detected_lang, translation = await process_audio(file, lang_hint)
189
+ result = run_chain(stt_chain, {"speech": translation})
190
  return {
191
  "transcription": transcription,
192
  "detected_language": detected_lang,
193
  "translation": translation,
194
+ "reply": result if isinstance(result, str) else result.get("reply", "")
195
  }
196
 
197
  @app.post("/autodoc")
198
  def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
199
  check_auth(authorization)
200
+ result = run_chain(autodoc_chain, {"code": req.code})
201
+ return result if isinstance(result, dict) else {"documentation": result}
202
 
203
  @app.post("/sme/generate")
204
  async def sme_generate(payload: dict = Body(...), authorization: str | None = Header(None)):
 
207
  user_prompt = payload.get("user_prompt", "")
208
  context_docs = retriever.get_relevant_documents(user_prompt)
209
  context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
210
+ result = run_chain(sme_chain, {"user_prompt": user_prompt, "context": context})
211
+ return {"success": True, "data": result if isinstance(result, str) else result.get("reply", "")}
212
+ except Exception:
213
+ return {"success": False, "error": "⚠️ LLM error", "details": traceback.format_exc()}
214
 
215
  @app.post("/sme/speech-generate")
216
  async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
217
  check_auth(authorization)
218
+ transcription, detected_lang, translation = await process_audio(file, lang_hint)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  try:
220
  context_docs = retriever.get_relevant_documents(translation)
221
  context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
222
+ result = run_chain(sme_chain, {"user_prompt": translation, "context": context})
223
  return {
224
  "success": True,
225
  "transcription": transcription,
226
  "detected_language": detected_lang,
227
  "translation": translation,
228
+ "sme_site": result if isinstance(result, str) else result.get("reply", "")
229
  }
230
+ except Exception:
231
+ return {"success": False, "error": "⚠️ LLM error", "details": traceback.format_exc()}
232
 
233
  # ----------------- MAIN -----------------
234
  if __name__ == "__main__":