# DevAssist / main.py
import os
import json
import tempfile
import traceback
from typing import Optional
from fastapi import FastAPI, UploadFile, File, Header, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from spitch import Spitch
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
from langdetect import detect, DetectorFactory
from huggingface_hub.utils import HfHubHTTPError
from smebuilder_vector import retriever # retriever that exposes .get_relevant_documents(...)
DetectorFactory.seed = 0
# ----------------- CONFIG -----------------
SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")
HF_MODEL = os.getenv("HF_MODEL", "deepseek-ai/deepseek-coder-1.3b-instruct")
FRONTEND_ORIGIN = os.getenv("ALLOWED_ORIGIN", "*")
PROJECT_API_KEY = os.getenv("PROJECT_API_KEY", "")
if not SPITCH_API_KEY:
raise RuntimeError("Set SPITCH_API_KEY in environment before starting.")
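# A minimal sketch of one possible local setup; every value below is a placeholder, and only
# SPITCH_API_KEY is strictly required:
#
#   export SPITCH_API_KEY="<your-spitch-key>"
#   export HF_MODEL="deepseek-ai/deepseek-coder-1.3b-instruct"     # optional model override
#   export ALLOWED_ORIGIN="https://your-frontend.example"          # optional CORS origin
#   export PROJECT_API_KEY="<shared-secret>"                       # optional, enables bearer auth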
# Init Spitch
os.environ["SPITCH_API_KEY"] = SPITCH_API_KEY
spitch_client = Spitch()
# HuggingFace LLM
# NOTE: generation parameters are passed as explicit keyword arguments (required by pydantic validation)
llm = HuggingFaceEndpoint(
repo_id=HF_MODEL,
temperature=0.7,
top_p=0.9,
do_sample=True,
repetition_penalty=1.1,
max_new_tokens=2048,
)
# FastAPI app
app = FastAPI(title="DevAssist AI Backend (FastAPI + LangChain)")
# CORS
app.add_middleware(
CORSMiddleware,
allow_origins=[FRONTEND_ORIGIN] if FRONTEND_ORIGIN != "*" else ["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["Authorization", "Content-Type"],
)
# ----------------- PROMPT TEMPLATES -----------------
chat_template = """You are DevAssist, an AI coding assistant.
Guidelines:
- Always format responses in Markdown.
- Do NOT use inline formatting like bold (**), italics (*), or underlines.
- Use plain text section headers (e.g., Explanation:, Steps to Fix:, Fixed Code:).
- Use bullet points (-, 1.) for steps.
- Use fenced code blocks (```python ... ```) for code.
- Be friendly yet professional; explain step by step.
Question: {question}
Answer:
"""
stt_chat_template = """You are DevAssist, an AI coding assistant.
- The input is transcribed speech. Interpret it as a dev question.
- Provide clear answers with code examples (use markdown triple backticks).
- If input is unclear, ask a clarifying question.
Spoken Question: {speech}
Answer:
"""
autodoc_template = """You are DevAssist DocBot.
- Read the code and produce professional documentation in markdown.
Code: {code}
Documentation:
"""
sme_template = """
You are a senior full-stack engineer specializing in modern front-end development.
Your job is to generate production-ready code for websites and apps.
Guidelines:
- Always return three separate files: index.html, styles.css, and script.js
- HTML must be semantic, responsive, and mobile-first (include <meta name="viewport">)
- CSS should use Flexbox/Grid and include hover/transition effects
- JavaScript should add interactivity (e.g. button actions, basic animations, toggles)
- Include a hero section, a feature grid, testimonials, and footer
- Use realistic content (avoid lorem ipsum), sensible copy, and accessible markup
- Return only valid JSON with a single top-level key "files" mapping to { "index.html": "...", "styles.css": "...", "script.js": "..." }
User Prompt:
{user_prompt}
Context:
{context}
Return:
"""
# ----------------- CHAINS -----------------
chat_chain = PromptTemplate(input_variables=["question"], template=chat_template) | llm
stt_chain = PromptTemplate(input_variables=["speech"], template=stt_chat_template) | llm
autodoc_chain = PromptTemplate(input_variables=["code"], template=autodoc_template) | llm
sme_chain = PromptTemplate(input_variables=["user_prompt", "context"], template=sme_template) | llm
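# Illustrative only: the JSON shape sme_template asks the model to return. The keys come from
# the prompt above; the file contents here are stand-ins, not real model output:
#
#   {
#     "files": {
#       "index.html": "<!DOCTYPE html>...",
#       "styles.css": "body { margin: 0; }",
#       "script.js": "document.addEventListener('DOMContentLoaded', ...)"
#     }
#   }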
# ----------------- REQUEST MODELS -----------------
class ChatRequest(BaseModel):
question: str
class AutoDocRequest(BaseModel):
code: str
# ----------------- AUTH -----------------
def check_auth(authorization: Optional[str] = None):
if not PROJECT_API_KEY:
# No API key enforced in this environment
return
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Missing bearer token")
token = authorization.split(" ", 1)[1]
if token != PROJECT_API_KEY:
raise HTTPException(status_code=403, detail="Invalid token")
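# When PROJECT_API_KEY is set, clients authenticate by sending it as a bearer token
# (placeholder value shown):
#
#   Authorization: Bearer <PROJECT_API_KEY value>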
# ----------------- HELPERS -----------------
def try_parse_json(maybe_str: str):
    """Try to parse the whole string as JSON; if that fails, try a trailing JSON object; return None when neither works."""
try:
return json.loads(maybe_str)
except Exception:
# attempt to find a JSON substring
import re
m = re.search(r"\{[\s\S]*\}\s*$", maybe_str.strip())
if m:
try:
return json.loads(m.group(0))
except Exception:
return None
return None
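# Illustrative behaviour of try_parse_json (not executed here; results follow from the logic above):
#
#   try_parse_json('{"files": {}}')               -> {"files": {}}
#   try_parse_json('Sure, here it is: {"a": 1}')  -> {"a": 1}   (trailing JSON object extracted)
#   try_parse_json('no json here')                -> None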
# ----------------- ENDPOINTS -----------------
@app.get("/")
def root():
return {"status": "DevAssist AI Backend running"}
@app.post("/chat")
def chat(req: ChatRequest, authorization: Optional[str] = Header(None)):
check_auth(authorization)
try:
answer = chat_chain.invoke({"question": req.question})
return {"reply": answer.strip() if isinstance(answer, str) else str(answer)}
except HfHubHTTPError as e:
if "exceeded" in str(e).lower() or "quota" in str(e).lower():
return {"reply": "⚠️ Daily token limit reached. Try again in 24 hours."}
raise e
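# Example request, as a sketch: assumes the service is running on the default port from the
# __main__ block below, that the `requests` package is available (it is not a dependency of
# this file), and that the bearer token is a placeholder.
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/chat",
#       headers={"Authorization": "Bearer <PROJECT_API_KEY value>"},
#       json={"question": "How do I reverse a list in Python?"},
#   )
#   print(r.json()["reply"])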
@app.post("/stt")
async def stt_audio(
file: UploadFile = File(...),
lang_hint: Optional[str] = None,
authorization: Optional[str] = Header(None),
):
check_auth(authorization)
    suffix = os.path.splitext(file.filename or "")[1] or ".wav"
# create temp file
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
content = await file.read()
tf.write(content)
tmp_path = tf.name
    # read the audio bytes once so no file handles are left open
    with open(tmp_path, "rb") as audio_file:
        audio_bytes = audio_file.read()
    try:
        # transcribe, honoring the caller's language hint when provided
        if lang_hint:
            resp = spitch_client.speech.transcribe(language=lang_hint, content=audio_bytes)
        else:
            resp = spitch_client.speech.transcribe(content=audio_bytes)
    except Exception:
        # fall back to English transcription if the hinted/auto-detected attempt fails
        resp = spitch_client.speech.transcribe(language="en", content=audio_bytes)
transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
try:
detected_lang = detect(transcription) if transcription.strip() else "en"
except Exception:
detected_lang = "en"
translation = transcription
if detected_lang != "en":
try:
translation_resp = spitch_client.text.translate(text=transcription, source=detected_lang, target="en")
translation = getattr(translation_resp, "text", "") or (translation_resp.get("text", "") if isinstance(translation_resp, dict) else translation)
except Exception:
translation = transcription
# call the STT chain (LLM)
try:
reply = stt_chain.invoke({"speech": translation})
    except Exception:
        # if the LLM call fails, still return the transcription so the caller gets something useful
        reply = f"(LLM error) Transcription: {translation}"
# cleanup temp file to avoid storage bloat
try:
os.remove(tmp_path)
except Exception:
pass
return {
"transcription": transcription,
"detected_language": detected_lang,
"translation": translation,
"reply": reply.strip() if isinstance(reply, str) else str(reply),
}
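# Example request, as a sketch: lang_hint is an optional query parameter, sample.wav is a
# placeholder audio file, and `requests` is assumed to be installed.
#
#   import requests
#   with open("sample.wav", "rb") as f:
#       r = requests.post(
#           "http://localhost:7860/stt",
#           params={"lang_hint": "en"},
#           headers={"Authorization": "Bearer <PROJECT_API_KEY value>"},
#           files={"file": ("sample.wav", f, "audio/wav")},
#       )
#   print(r.json()["transcription"], r.json()["reply"])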
@app.post("/autodoc")
def autodoc(req: AutoDocRequest, authorization: Optional[str] = Header(None)):
check_auth(authorization)
docs = autodoc_chain.invoke({"code": req.code})
return {"documentation": docs.strip() if isinstance(docs, str) else str(docs)}
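# Example request, as a sketch (same assumptions as the /chat example above):
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/autodoc",
#       headers={"Authorization": "Bearer <PROJECT_API_KEY value>"},
#       json={"code": "def add(a, b):\n    return a + b"},
#   )
#   print(r.json()["documentation"])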
@app.post("/sme/generate")
async def sme_generate(payload: dict = Body(...), authorization: Optional[str] = Header(None)):
    """
    Expected payload: { "user_prompt": "Create ..." } (an optional "force_simple" flag may be sent but is not currently used).
    Returns: {"success": True, "data": ...} on success, otherwise {"success": False, "error": ...}.
    """
check_auth(authorization)
user_prompt = payload.get("user_prompt", "")
if not user_prompt or not user_prompt.strip():
raise HTTPException(status_code=400, detail="user_prompt is required")
# Get context from retriever (if available)
try:
context_docs = retriever.get_relevant_documents(user_prompt) if retriever else []
context = "\n\n".join([getattr(d, "page_content", str(d)) for d in context_docs]) if context_docs else "No extra context"
except Exception:
context = "No extra context"
# Invoke SME chain
try:
raw = sme_chain.invoke({"user_prompt": user_prompt, "context": context})
# Try to parse returned JSON
parsed = None
if isinstance(raw, str):
parsed = try_parse_json(raw)
elif isinstance(raw, dict):
parsed = raw
if parsed:
return {"success": True, "data": parsed}
else:
# If model didn't return strict JSON, return helpful error + raw output so frontend can show it
return {"success": False, "error": "LLM did not return valid JSON", "raw": raw}
except HfHubHTTPError as e:
if "exceeded" in str(e).lower() or "quota" in str(e).lower():
return {"success": False, "error": "⚠️ Token quota for today has been used. Please come back in 24 hours."}
raise e
except Exception as e:
# Debug info for devs (but don't leak sensitive internals in production)
return {"success": False, "error": "SME generation failed", "details": str(e), "trace": traceback.format_exc()}
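# Example payload and response shapes, as a sketch (the prompt text is illustrative only;
# the keys mirror the returns above):
#
#   request:  {"user_prompt": "Create a landing page for a small bakery"}
#   success:  {"success": true, "data": {"files": {"index.html": "...", "styles.css": "...", "script.js": "..."}}}
#   failure:  {"success": false, "error": "LLM did not return valid JSON", "raw": "..."}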
@app.post("/sme/speech-generate")
async def sme_speech_generate(
file: UploadFile = File(...),
lang_hint: Optional[str] = None,
authorization: Optional[str] = Header(None),
):
check_auth(authorization)
    suffix = os.path.splitext(file.filename or "")[1] or ".wav"
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
content = await file.read()
tf.write(content)
tmp_path = tf.name
    # read the audio bytes once so no file handles are left open
    with open(tmp_path, "rb") as audio_file:
        audio_bytes = audio_file.read()
    try:
        if lang_hint:
            resp = spitch_client.speech.transcribe(language=lang_hint, content=audio_bytes)
        else:
            resp = spitch_client.speech.transcribe(content=audio_bytes)
    except Exception:
        # fall back to English transcription if the hinted/auto-detected attempt fails
        resp = spitch_client.speech.transcribe(language="en", content=audio_bytes)
transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
try:
detected_lang = detect(transcription) if transcription.strip() else "en"
except Exception:
detected_lang = "en"
translation = transcription
if detected_lang != "en":
try:
translation_resp = spitch_client.text.translate(text=transcription, source=detected_lang, target="en")
translation = getattr(translation_resp, "text", "") or (translation_resp.get("text", "") if isinstance(translation_resp, dict) else translation)
except Exception:
translation = transcription
# Get context docs for the transcribed prompt
try:
context_docs = retriever.get_relevant_documents(translation) if retriever else []
context = "\n\n".join([getattr(d, "page_content", str(d)) for d in context_docs]) if context_docs else "No extra context"
except Exception:
context = "No extra context"
# Invoke SME chain
try:
raw = sme_chain.invoke({"user_prompt": translation, "context": context})
parsed = None
if isinstance(raw, str):
parsed = try_parse_json(raw)
elif isinstance(raw, dict):
parsed = raw
# cleanup tmp file
try:
os.remove(tmp_path)
except Exception:
pass
if parsed:
return {
"success": True,
"transcription": transcription,
"detected_language": detected_lang,
"translation": translation,
"sme_site": parsed,
}
else:
return {
"success": False,
"error": "LLM did not return valid JSON",
"raw": raw,
"transcription": transcription,
"detected_language": detected_lang,
"translation": translation,
}
except HfHubHTTPError as e:
try:
os.remove(tmp_path)
except Exception:
pass
if "exceeded" in str(e).lower() or "quota" in str(e).lower():
return {"success": False, "error": "⚠️ Token quota for today has been used. Please come back in 24 hours."}
raise e
except Exception as e:
try:
os.remove(tmp_path)
except Exception:
pass
return {"success": False, "error": "SME generation failed", "details": str(e), "trace": traceback.format_exc()}
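# Example request, as a sketch: prompt.wav is a placeholder audio file and `requests` is
# assumed to be installed.
#
#   import requests
#   with open("prompt.wav", "rb") as f:
#       r = requests.post(
#           "http://localhost:7860/sme/speech-generate",
#           params={"lang_hint": "en"},
#           headers={"Authorization": "Bearer <PROJECT_API_KEY value>"},
#           files={"file": ("prompt.wav", f, "audio/wav")},
#       )
#   print(r.json().get("sme_site") or r.json().get("error"))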
# ----------------- MAIN -----------------
if __name__ == "__main__":
import uvicorn
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
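# Run locally with `python main.py`, or equivalently `uvicorn main:app --host 0.0.0.0 --port 7860`.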