lydiasolomon commited on
Commit
f14c1fe
·
verified ·
1 Parent(s): d48566f

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +58 -58
main.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
2
  import tempfile
3
- from fastapi import FastAPI, UploadFile, File, Header, HTTPException, Body
 
 
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from pydantic import BaseModel
6
  from spitch import Spitch
@@ -10,13 +12,17 @@ from langdetect import detect, DetectorFactory
10
  from huggingface_hub.utils import HfHubHTTPError
11
  from smebuilder_vector import retriever # Retriever for context injection
12
 
 
 
 
 
13
  # ----------------- CONFIG -----------------
14
  DetectorFactory.seed = 0
15
 
16
  SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")
17
  HF_MODEL = os.getenv("HF_MODEL", "deepseek-ai/deepseek-coder-1.3b-instruct")
18
  FRONTEND_ORIGIN = os.getenv("ALLOWED_ORIGIN", "*")
19
- PROJECT_API_KEY = os.getenv("PROJECT_API_KEY", "super-secret-123") # Default if not set
20
 
21
  if not SPITCH_API_KEY:
22
  raise RuntimeError("Set SPITCH_API_KEY in environment before starting.")
@@ -35,7 +41,31 @@ llm = HuggingFaceEndpoint(
35
  max_new_tokens=2048
36
  )
37
 
38
- # FastAPI app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  app = FastAPI(title="DevAssist AI Backend (FastAPI + LangChain)")
40
 
41
  # CORS
@@ -47,7 +77,13 @@ app.add_middleware(
47
  allow_headers=["Authorization", "Content-Type"],
48
  )
49
 
50
- # ----------------- PROMPT TEMPLATES -----------------
 
 
 
 
 
 
51
  chat_template = """You are DevAssist, an AI coding assistant.
52
 
53
  Guidelines:
@@ -103,7 +139,7 @@ stt_chain = PromptTemplate(input_variables=["speech"], template=stt_chat_templat
103
  autodoc_chain = PromptTemplate(input_variables=["code"], template=autodoc_template) | llm
104
  sme_chain = PromptTemplate(input_variables=["user_prompt", "context"], template=sme_template) | llm
105
 
106
- # ----------------- REQUEST MODELS -----------------
107
  class ChatRequest(BaseModel):
108
  question: str
109
 
@@ -112,8 +148,7 @@ class AutoDocRequest(BaseModel):
112
 
113
  # ----------------- AUTH -----------------
114
  def check_auth(authorization: str | None):
115
- """Validate Bearer token against PROJECT_API_KEY"""
116
- if not PROJECT_API_KEY: # If not set, skip auth
117
  return
118
  if not authorization or not authorization.startswith("Bearer "):
119
  raise HTTPException(status_code=401, detail="Missing bearer token")
@@ -129,28 +164,18 @@ def root():
129
  @app.post("/chat")
130
  def chat(req: ChatRequest, authorization: str | None = Header(None)):
131
  check_auth(authorization)
132
- try:
133
- answer = chat_chain.invoke({"question": req.question})
134
- return {"reply": answer.strip() if isinstance(answer, str) else str(answer)}
135
- except HfHubHTTPError as e:
136
- if "exceeded" in str(e).lower() or "quota" in str(e).lower():
137
- return {"reply": "⚠️ Daily token limit reached. Try again in 24 hours."}
138
- raise e
139
 
140
  @app.post("/stt")
141
  async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
142
  check_auth(authorization)
143
-
144
  suffix = os.path.splitext(file.filename)[1] or ".wav"
145
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
146
  tf.write(await file.read())
147
  tmp_path = tf.name
148
 
149
  try:
150
- if lang_hint:
151
- resp = spitch_client.speech.transcribe(language=lang_hint, content=open(tmp_path, "rb").read())
152
- else:
153
- resp = spitch_client.speech.transcribe(content=open(tmp_path, "rb").read())
154
  except Exception:
155
  resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
156
 
@@ -169,48 +194,34 @@ async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None,
169
  except Exception:
170
  translation = transcription
171
 
172
- reply = stt_chain.invoke({"speech": translation})
173
- return {
174
- "transcription": transcription,
175
- "detected_language": detected_lang,
176
- "translation": translation,
177
- "reply": reply.strip() if isinstance(reply, str) else str(reply)
178
- }
179
 
180
  @app.post("/autodoc")
181
  def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
182
  check_auth(authorization)
183
- docs = autodoc_chain.invoke({"code": req.code})
184
- return {"documentation": docs.strip() if isinstance(docs, str) else str(docs)}
185
 
186
  @app.post("/sme/generate")
187
  async def sme_generate(payload: dict = Body(...), authorization: str | None = Header(None)):
188
  check_auth(authorization)
189
- try:
190
- user_prompt = payload.get("user_prompt", "")
191
- context_docs = retriever.get_relevant_documents(user_prompt)
192
- context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
193
- response = sme_chain.invoke({"user_prompt": user_prompt, "context": context})
194
- return {"success": True, "data": response}
195
- except HfHubHTTPError as e:
196
- if "exceeded" in str(e).lower() or "quota" in str(e).lower():
197
- return {"success": False, "error": "⚠️ Token quota for today has been used. Please come back in 24 hours."}
198
- raise e
199
 
200
  @app.post("/sme/speech-generate")
201
  async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
202
  check_auth(authorization)
203
-
204
  suffix = os.path.splitext(file.filename)[1] or ".wav"
205
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
206
  tf.write(await file.read())
207
  tmp_path = tf.name
208
 
209
  try:
210
- if lang_hint:
211
- resp = spitch_client.speech.transcribe(language=lang_hint, content=open(tmp_path, "rb").read())
212
- else:
213
- resp = spitch_client.speech.transcribe(content=open(tmp_path, "rb").read())
214
  except Exception:
215
  resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
216
 
@@ -229,21 +240,10 @@ async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | Non
229
  except Exception:
230
  translation = transcription
231
 
232
- try:
233
- context_docs = retriever.get_relevant_documents(translation)
234
- context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
235
- sme_response = sme_chain.invoke({"user_prompt": translation, "context": context})
236
- return {
237
- "success": True,
238
- "transcription": transcription,
239
- "detected_language": detected_lang,
240
- "translation": translation,
241
- "sme_site": sme_response
242
- }
243
- except HfHubHTTPError as e:
244
- if "exceeded" in str(e).lower() or "quota" in str(e).lower():
245
- return {"success": False, "error": "⚠️ Token quota for today has been used. Please come back in 24 hours."}
246
- raise e
247
 
248
  # ----------------- MAIN -----------------
249
  if __name__ == "__main__":
 
1
  import os
2
  import tempfile
3
+ import logging
4
+ from fastapi import FastAPI, UploadFile, File, Header, HTTPException, Body, Request
5
+ from fastapi.responses import JSONResponse
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from pydantic import BaseModel
8
  from spitch import Spitch
 
12
  from huggingface_hub.utils import HfHubHTTPError
13
  from smebuilder_vector import retriever # Retriever for context injection
14
 
15
+ # ----------------- LOGGING -----------------
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger("DevAssist")
18
+
19
  # ----------------- CONFIG -----------------
20
  DetectorFactory.seed = 0
21
 
22
  SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")
23
  HF_MODEL = os.getenv("HF_MODEL", "deepseek-ai/deepseek-coder-1.3b-instruct")
24
  FRONTEND_ORIGIN = os.getenv("ALLOWED_ORIGIN", "*")
25
+ PROJECT_API_KEY = os.getenv("PROJECT_API_KEY") # default if not set
26
 
27
  if not SPITCH_API_KEY:
28
  raise RuntimeError("Set SPITCH_API_KEY in environment before starting.")
 
41
  max_new_tokens=2048
42
  )
43
 
44
+ # ----------------- HELPERS -----------------
45
+ def run_llm_model(chain, payload: dict):
46
+ """
47
+ Safely run HuggingFace model through LangChain chain.
48
+ Handles string, dict, and list responses without crashing.
49
+ """
50
+ try:
51
+ result = chain.invoke(payload)
52
+ logger.info(f"HF raw response: {result}")
53
+
54
+ if isinstance(result, str):
55
+ return result.strip()
56
+
57
+ if isinstance(result, dict) and "generated_text" in result:
58
+ return result["generated_text"].strip()
59
+
60
+ if isinstance(result, list) and len(result) > 0 and "generated_text" in result[0]:
61
+ return result[0]["generated_text"].strip()
62
+
63
+ return str(result).strip()
64
+ except Exception as e:
65
+ logger.error(f"LLM execution failed: {e}")
66
+ return f"⚠️ LLM error: {str(e)}"
67
+
68
+ # ----------------- FASTAPI -----------------
69
  app = FastAPI(title="DevAssist AI Backend (FastAPI + LangChain)")
70
 
71
  # CORS
 
77
  allow_headers=["Authorization", "Content-Type"],
78
  )
79
 
80
+ # Global exception handler
81
+ @app.exception_handler(Exception)
82
+ async def global_exception_handler(request: Request, exc: Exception):
83
+ logger.error(f"Unhandled error: {exc}")
84
+ return JSONResponse(status_code=500, content={"error": str(exc)})
85
+
86
+ # ----------------- PROMPTS -----------------
87
  chat_template = """You are DevAssist, an AI coding assistant.
88
 
89
  Guidelines:
 
139
  autodoc_chain = PromptTemplate(input_variables=["code"], template=autodoc_template) | llm
140
  sme_chain = PromptTemplate(input_variables=["user_prompt", "context"], template=sme_template) | llm
141
 
142
+ # ----------------- MODELS -----------------
143
  class ChatRequest(BaseModel):
144
  question: str
145
 
 
148
 
149
  # ----------------- AUTH -----------------
150
  def check_auth(authorization: str | None):
151
+ if not PROJECT_API_KEY: # if no key set, skip
 
152
  return
153
  if not authorization or not authorization.startswith("Bearer "):
154
  raise HTTPException(status_code=401, detail="Missing bearer token")
 
164
  @app.post("/chat")
165
  def chat(req: ChatRequest, authorization: str | None = Header(None)):
166
  check_auth(authorization)
167
+ return {"reply": run_llm_model(chat_chain, {"question": req.question})}
 
 
 
 
 
 
168
 
169
  @app.post("/stt")
170
  async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
171
  check_auth(authorization)
 
172
  suffix = os.path.splitext(file.filename)[1] or ".wav"
173
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
174
  tf.write(await file.read())
175
  tmp_path = tf.name
176
 
177
  try:
178
+ resp = spitch_client.speech.transcribe(language=lang_hint or "en", content=open(tmp_path, "rb").read())
 
 
 
179
  except Exception:
180
  resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
181
 
 
194
  except Exception:
195
  translation = transcription
196
 
197
+ reply = run_llm_model(stt_chain, {"speech": translation})
198
+ return {"transcription": transcription, "detected_language": detected_lang, "translation": translation, "reply": reply}
 
 
 
 
 
199
 
200
  @app.post("/autodoc")
201
  def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
202
  check_auth(authorization)
203
+ docs = run_llm_model(autodoc_chain, {"code": req.code})
204
+ return {"documentation": docs}
205
 
206
  @app.post("/sme/generate")
207
  async def sme_generate(payload: dict = Body(...), authorization: str | None = Header(None)):
208
  check_auth(authorization)
209
+ user_prompt = payload.get("user_prompt", "")
210
+ context_docs = retriever.get_relevant_documents(user_prompt)
211
+ context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
212
+ response = run_llm_model(sme_chain, {"user_prompt": user_prompt, "context": context})
213
+ return {"success": True, "data": response}
 
 
 
 
 
214
 
215
  @app.post("/sme/speech-generate")
216
  async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
217
  check_auth(authorization)
 
218
  suffix = os.path.splitext(file.filename)[1] or ".wav"
219
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
220
  tf.write(await file.read())
221
  tmp_path = tf.name
222
 
223
  try:
224
+ resp = spitch_client.speech.transcribe(language=lang_hint or "en", content=open(tmp_path, "rb").read())
 
 
 
225
  except Exception:
226
  resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
227
 
 
240
  except Exception:
241
  translation = transcription
242
 
243
+ context_docs = retriever.get_relevant_documents(translation)
244
+ context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
245
+ sme_response = run_llm_model(sme_chain, {"user_prompt": translation, "context": context})
246
+ return {"success": True, "transcription": transcription, "detected_language": detected_lang, "translation": translation, "sme_site": sme_response}
 
 
 
 
 
 
 
 
 
 
 
247
 
248
  # ----------------- MAIN -----------------
249
  if __name__ == "__main__":