MohitGupta41 committed on
Commit ·
62960aa
1
Parent(s): 8b9d771
Increased Context Window and improved prompt
Browse files
app.py
CHANGED
|
@@ -8,6 +8,7 @@ from fastapi.responses import JSONResponse
|
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
from pydantic import BaseModel, Field, ConfigDict
|
| 10 |
import httpx
|
|
|
|
| 11 |
from datetime import date
|
| 12 |
|
| 13 |
from Constants import CONTEXT
|
|
@@ -120,46 +121,62 @@ async def call_gemini(
|
|
| 120 |
# If we got here, all attempts failed
|
| 121 |
raise last_err or HTTPException(502, "Gemini request failed")
|
| 122 |
|
| 123 |
-
async def call_huggingface_inference(
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
"""
|
| 130 |
-
|
| 131 |
"""
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
"
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
async with httpx.AsyncClient(timeout=120) as client:
|
| 144 |
-
r = await client.post(url, headers=headers, json=payload)
|
| 145 |
-
|
| 146 |
-
if r.status_code == 200:
|
| 147 |
-
data = r.json()
|
| 148 |
-
# HF returns either a list[{"generated_text": "..."}] or a dict with error/stream info
|
| 149 |
-
if isinstance(data, list) and data and "generated_text" in data[0]:
|
| 150 |
-
return data[0]["generated_text"].strip()
|
| 151 |
-
# Some pipelines return dict with "generated_text"
|
| 152 |
-
if isinstance(data, dict) and "generated_text" in data:
|
| 153 |
-
return data["generated_text"].strip()
|
| 154 |
-
# Some models return plain string
|
| 155 |
-
if isinstance(data, str):
|
| 156 |
-
return data.strip()
|
| 157 |
-
raise HTTPException(502, f"Unexpected HF response format: {data}")
|
| 158 |
-
elif r.status_code == 503:
|
| 159 |
-
# Model is loading or warming up
|
| 160 |
-
raise HTTPException(503, "Hugging Face model is loading. Please retry.")
|
| 161 |
-
else:
|
| 162 |
-
raise HTTPException(r.status_code, f"Hugging Face error: {r.text}")
|
| 163 |
|
| 164 |
# ---------- FastAPI ----------
|
| 165 |
app = FastAPI(title="Voice Agent API", version="0.2.0")
|
|
@@ -238,19 +255,28 @@ async def chat(
|
|
| 238 |
text = await call_gemini(gemini_key, model, prompt)
|
| 239 |
return ChatOut(answer=text or "Sorry, I didn't catch that.")
|
| 240 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
elif provider == "huggingface":
|
| 242 |
model = payload.model or os.getenv("DEFAULT_HF_MODEL", "google/gemma-3-27b-it")
|
| 243 |
-
|
| 244 |
-
hf_key = payload.hf_api_key or x_hf_api_key
|
| 245 |
-
if not hf_key and authorization and authorization.lower().startswith("bearer "):
|
| 246 |
-
hf_key = authorization.split(" ", 1)[1].strip()
|
| 247 |
-
if not hf_key:
|
| 248 |
-
hf_key = os.getenv("HF_API_KEY")
|
| 249 |
if not hf_key:
|
| 250 |
raise HTTPException(400, "Hugging Face API key is required (send hf_api_key, X-Hf-Api-Key, or Authorization: Bearer).")
|
| 251 |
-
text = await call_huggingface_inference(hf_key, model, prompt)
|
| 252 |
-
return ChatOut(answer=text or "Sorry, I didn't catch that.")
|
| 253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
else:
|
| 255 |
raise HTTPException(400, f"Unknown provider: {provider}")
|
| 256 |
|
|
|
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
from pydantic import BaseModel, Field, ConfigDict
|
| 10 |
import httpx
|
| 11 |
+
from huggingface_hub import InferenceClient
|
| 12 |
from datetime import date
|
| 13 |
|
| 14 |
from Constants import CONTEXT
|
|
|
|
| 121 |
# If we got here, all attempts failed
|
| 122 |
raise last_err or HTTPException(502, "Gemini request failed")
|
| 123 |
|
| 124 |
+
# async def call_huggingface_inference(
|
| 125 |
+
# hf_api_key: str,
|
| 126 |
+
# model: str,
|
| 127 |
+
# prompt: str,
|
| 128 |
+
# parameters: Optional[Dict[str, Any]] = None
|
| 129 |
+
# ) -> str:
|
| 130 |
+
# """
|
| 131 |
+
# Calls Hugging Face Inference API for text generation models (e.g., google/gemma-3-27b-it).
|
| 132 |
+
# """
|
| 133 |
+
# parameters = parameters or {
|
| 134 |
+
# "max_new_tokens": CONTEXT,
|
| 135 |
+
# "temperature": 0.2,
|
| 136 |
+
# "return_full_text": False,
|
| 137 |
+
# "repetition_penalty": 1.1,
|
| 138 |
+
# }
|
| 139 |
+
|
| 140 |
+
# url = f"https://api-inference.huggingface.co/models/{model}"
|
| 141 |
+
# headers = {"Authorization": f"Bearer {hf_api_key}"}
|
| 142 |
+
# payload = {"inputs": prompt, "parameters": parameters}
|
| 143 |
+
|
| 144 |
+
# async with httpx.AsyncClient(timeout=120) as client:
|
| 145 |
+
# r = await client.post(url, headers=headers, json=payload)
|
| 146 |
+
|
| 147 |
+
# if r.status_code == 200:
|
| 148 |
+
# data = r.json()
|
| 149 |
+
# # HF returns either a list[{"generated_text": "..."}] or a dict with error/stream info
|
| 150 |
+
# if isinstance(data, list) and data and "generated_text" in data[0]:
|
| 151 |
+
# return data[0]["generated_text"].strip()
|
| 152 |
+
# # Some pipelines return dict with "generated_text"
|
| 153 |
+
# if isinstance(data, dict) and "generated_text" in data:
|
| 154 |
+
# return data["generated_text"].strip()
|
| 155 |
+
# # Some models return plain string
|
| 156 |
+
# if isinstance(data, str):
|
| 157 |
+
# return data.strip()
|
| 158 |
+
# raise HTTPException(502, f"Unexpected HF response format: {data}")
|
| 159 |
+
# elif r.status_code == 503:
|
| 160 |
+
# # Model is loading or warming up
|
| 161 |
+
# raise HTTPException(503, "Hugging Face model is loading. Please retry.")
|
| 162 |
+
# else:
|
| 163 |
+
# raise HTTPException(r.status_code, f"Hugging Face error: {r.text}")
|
| 164 |
+
|
| 165 |
+
async def call_hf_chat(hf_api_key: str, model: str, messages, *, provider: str | None = "auto",
                       max_tokens: int = 1024, temperature: float = 0.2) -> str:
    """
    Uses Hugging Face Inference Providers (OpenAI-compatible chat completions).

    Args:
        hf_api_key: Hugging Face API token used to authenticate the request.
        model: Model id, e.g. "google/gemma-3-27b-it".
        messages: Chat messages, e.g. [{"role": "user", "content": "..."}]
            OR a multimodal content structure.
        provider: Inference-provider routing hint; "auto" lets HF choose.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.

    Returns:
        The assistant reply text, stripped of surrounding whitespace
        ("" if the model returned no text content).
    """
    import asyncio  # local import so the module-level import block is untouched

    def _blocking_call() -> str:
        # InferenceClient is synchronous; it is executed in a worker thread
        # below so the blocking HTTP request does not stall the event loop.
        client = InferenceClient(api_key=hf_api_key, provider=provider, timeout=120)
        resp = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            stream=False,
        )
        # hf client returns an OpenAI-style response; content can be None
        # (e.g. a tool-call-only reply), so guard before stripping.
        content = resp.choices[0].message.content
        return (content or "").strip()

    return await asyncio.to_thread(_blocking_call)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
# ---------- FastAPI ----------
|
| 182 |
app = FastAPI(title="Voice Agent API", version="0.2.0")
|
|
|
|
| 255 |
text = await call_gemini(gemini_key, model, prompt)
|
| 256 |
return ChatOut(answer=text or "Sorry, I didn't catch that.")
|
| 257 |
|
| 258 |
+
# elif provider == "huggingface":
|
| 259 |
+
# model = payload.model or os.getenv("DEFAULT_HF_MODEL", "google/gemma-3-27b-it")
|
| 260 |
+
# # choose key from body > header (X-Hf-Api-Key) > Authorization Bearer > env
|
| 261 |
+
# hf_key = payload.hf_api_key or x_hf_api_key
|
| 262 |
+
# if not hf_key and authorization and authorization.lower().startswith("bearer "):
|
| 263 |
+
# hf_key = authorization.split(" ", 1)[1].strip()
|
| 264 |
+
# if not hf_key:
|
| 265 |
+
# hf_key = os.getenv("HF_API_KEY")
|
| 266 |
+
# if not hf_key:
|
| 267 |
+
# raise HTTPException(400, "Hugging Face API key is required (send hf_api_key, X-Hf-Api-Key, or Authorization: Bearer).")
|
| 268 |
+
# text = await call_huggingface_inference(hf_key, model, prompt)
|
| 269 |
+
# return ChatOut(answer=text or "Sorry, I didn't catch that.")
|
| 270 |
elif provider == "huggingface":
|
| 271 |
model = payload.model or os.getenv("DEFAULT_HF_MODEL", "google/gemma-3-27b-it")
|
| 272 |
+
hf_key = payload.hf_api_key or x_hf_api_key or (authorization.split(" ",1)[1].strip() if authorization and authorization.lower().startswith("bearer ") else None) or os.getenv("HF_API_KEY")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
if not hf_key:
|
| 274 |
raise HTTPException(400, "Hugging Face API key is required (send hf_api_key, X-Hf-Api-Key, or Authorization: Bearer).")
|
|
|
|
|
|
|
| 275 |
|
| 276 |
+
messages = [{"role":"user","content": build_prompt(payload.question)}]
|
| 277 |
+
text = await call_hf_chat(hf_key, model, messages, provider="auto")
|
| 278 |
+
return ChatOut(answer=text or "Sorry, I didn't catch that.")
|
| 279 |
+
|
| 280 |
else:
|
| 281 |
raise HTTPException(400, f"Unknown provider: {provider}")
|
| 282 |
|