Spaces:
Running
Running
RAG test1
Browse files- api/main.py +49 -0
- hybrid/assistant.py +15 -34
- models/llm.py +18 -26
api/main.py
CHANGED
|
@@ -85,6 +85,17 @@ class HybridQueryRequest(BaseModel):
|
|
| 85 |
use_web_fallback: bool = True
|
| 86 |
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
# NEW: Speech-to-Text Models
|
| 89 |
class TranscribeRequest(BaseModel):
|
| 90 |
audio_filename: str
|
|
@@ -512,6 +523,44 @@ async def hybrid_query(request: HybridQueryRequest):
|
|
| 512 |
raise HTTPException(status_code=500, detail=str(e))
|
| 513 |
|
| 514 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
# ============================================================================
|
| 516 |
# VOICE-TO-TEXT ENDPOINTS (NEW)
|
| 517 |
# ============================================================================
|
|
|
|
| 85 |
use_web_fallback: bool = True
|
| 86 |
|
| 87 |
|
# Fast endpoints for Node-side orchestration
class EmbedRequest(BaseModel):
    # Raw text to embed with the sentence-transformer (consumed by /embed).
    text: str
| 91 |
+
|
| 92 |
+
|
class GenerateRequest(BaseModel):
    # User question to answer.
    query: str
    # Pre-built retrieval context; the Node backend does retrieval itself.
    context: str
    source_type: str = "documents"  # "documents" | "web"
| 97 |
+
|
| 98 |
+
|
# NEW: Speech-to-Text Models
class TranscribeRequest(BaseModel):
    # Name of a previously uploaded audio file to transcribe.
    audio_filename: str
|
|
|
| 523 |
raise HTTPException(status_code=500, detail=str(e))
|
| 524 |
|
| 525 |
|
| 526 |
# ============================================================================
# FAST PRIMITIVE ENDPOINTS (used by Node backend for server-side RAG)
# ============================================================================

@app.post("/embed")
async def embed_text(request: EmbedRequest):
    """
    Embed a single text string and return its float vector.

    Uses only the sentence-transformer (fast, no LLM needed).

    Returns:
        dict with "embedding" (list of floats) and "dimension" (vector length).

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        # Imported lazily so the embedding model is only loaded on first use.
        from models.embeddings import get_embedding_model

        embedding_model = get_embedding_model()
        vector = embedding_model.encode_query(request.text)
        return {"embedding": vector.tolist(), "dimension": len(vector)}
    except Exception as e:
        # Chain the original exception so server logs keep the root cause.
        raise HTTPException(status_code=500, detail=str(e)) from e
| 543 |
+
|
| 544 |
+
|
| 545 |
@app.post("/generate")
async def generate_answer(request: GenerateRequest):
    """
    Generate a short answer given pre-built context.

    Called by the Node backend after it has already done retrieval from
    MongoDB. Much faster than /assistant because no retrieval step happens
    here.

    Returns:
        dict with a single "answer" string.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        assistant = get_hybrid_assistant_instance()
        # NOTE(review): this reaches into a private method of the assistant;
        # consider exposing a public generate_from_context() wrapper instead.
        answer = assistant._generate_answer(
            query=request.query,
            context=request.context,
            source_type=request.source_type,
        )
        return {"answer": answer}
    except Exception as e:
        # Chain the original exception so server logs keep the root cause.
        raise HTTPException(status_code=500, detail=str(e)) from e
| 562 |
+
|
| 563 |
+
|
| 564 |
# ============================================================================
|
| 565 |
# VOICE-TO-TEXT ENDPOINTS (NEW)
|
| 566 |
# ============================================================================
|
hybrid/assistant.py
CHANGED
|
@@ -126,46 +126,27 @@ class HybridAssistant:
|
|
| 126 |
context: str,
|
| 127 |
source_type: str
|
| 128 |
) -> str:
|
| 129 |
-
"""Generate answer from context"""
|
| 130 |
-
|
| 131 |
-
if source_type == "documents":
|
| 132 |
-
prompt = f"""You are a helpful AI assistant. Answer the question using ONLY the information from the provided context.
|
| 133 |
-
|
| 134 |
-
Context from uploaded documents:
|
| 135 |
-
{context}
|
| 136 |
-
|
| 137 |
-
Question: {query}
|
| 138 |
-
|
| 139 |
-
Instructions:
|
| 140 |
-
- Answer based on the context above
|
| 141 |
-
- Cite sources using [Source 1], [Source 2], etc.
|
| 142 |
-
- If the context doesn't fully answer the question, say so
|
| 143 |
-
- Be concise and accurate
|
| 144 |
-
|
| 145 |
-
Answer:"""
|
| 146 |
-
|
| 147 |
-
else: # web sources
|
| 148 |
-
prompt = f"""You are a helpful AI assistant. Answer the question using the information from web search results.
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
|
| 161 |
-
Answer:"""
|
| 162 |
-
|
| 163 |
response = self.llm.generate(
|
| 164 |
prompt=prompt,
|
| 165 |
-
max_new_tokens=
|
| 166 |
-
temperature=0.7
|
| 167 |
)
|
| 168 |
-
|
| 169 |
return response.strip()
|
| 170 |
|
| 171 |
# Singleton
|
|
|
|
| 126 |
context: str,
|
| 127 |
source_type: str
|
| 128 |
) -> str:
|
| 129 |
+
"""Generate answer from context using TinyLlama chat format for speed."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
+
# TinyLlama chat template — keeps token count minimal for fast inference
|
| 132 |
+
if source_type == "documents":
|
| 133 |
+
system = "You are a helpful study assistant. Answer ONLY using the provided context. Cite [Source N] inline."
|
| 134 |
+
user_content = f"Context:\n{context[:1500]}\n\nQuestion: {query}"
|
| 135 |
+
else:
|
| 136 |
+
system = "You are a helpful assistant. Summarise the web results to answer the question concisely. Cite [Web N] inline."
|
| 137 |
+
user_content = f"Web results:\n{context[:1500]}\n\nQuestion: {query}"
|
| 138 |
|
| 139 |
+
prompt = (
|
| 140 |
+
f"<|system|>\n{system}</s>\n"
|
| 141 |
+
f"<|user|>\n{user_content}</s>\n"
|
| 142 |
+
f"<|assistant|>\n"
|
| 143 |
+
)
|
| 144 |
|
|
|
|
|
|
|
| 145 |
response = self.llm.generate(
|
| 146 |
prompt=prompt,
|
| 147 |
+
max_new_tokens=150,
|
|
|
|
| 148 |
)
|
| 149 |
+
|
| 150 |
return response.strip()
|
| 151 |
|
| 152 |
# Singleton
|
models/llm.py
CHANGED
|
@@ -60,43 +60,35 @@ class LanguageModel:
|
|
| 60 |
def generate(
|
| 61 |
self,
|
| 62 |
prompt: str,
|
| 63 |
-
max_new_tokens: int =
|
| 64 |
temperature: float = TEMPERATURE,
|
| 65 |
top_p: float = TOP_P
|
| 66 |
) -> str:
|
| 67 |
"""
|
| 68 |
-
Generate text from prompt
|
| 69 |
-
|
| 70 |
-
Args:
|
| 71 |
-
prompt: Input prompt
|
| 72 |
-
max_new_tokens: Maximum tokens to generate
|
| 73 |
-
temperature: Sampling temperature
|
| 74 |
-
top_p: Top-p sampling
|
| 75 |
-
|
| 76 |
-
Returns:
|
| 77 |
-
Generated text
|
| 78 |
"""
|
| 79 |
-
inputs = self.tokenizer(
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
with torch.no_grad():
|
| 82 |
outputs = self.model.generate(
|
| 83 |
**inputs,
|
| 84 |
max_new_tokens=max_new_tokens,
|
| 85 |
-
|
| 86 |
-
top_p=top_p,
|
| 87 |
-
do_sample=True,
|
| 88 |
pad_token_id=self.tokenizer.pad_token_id,
|
| 89 |
-
eos_token_id=self.tokenizer.eos_token_id
|
|
|
|
| 90 |
)
|
| 91 |
-
|
| 92 |
-
# Decode
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
generated_text = generated_text[len(prompt):].strip()
|
| 98 |
-
|
| 99 |
-
return generated_text
|
| 100 |
|
| 101 |
# Singleton instance
|
| 102 |
_llm_model = None
|
|
|
|
def generate(
    self,
    prompt: str,
    max_new_tokens: int = 150,
    temperature: float = TEMPERATURE,
    top_p: float = TOP_P
) -> str:
    """
    Generate text from a prompt using greedy decoding for speed.

    Args:
        prompt: Input prompt text.
        max_new_tokens: Maximum number of new tokens to generate.
        temperature: Accepted for backward compatibility but IGNORED —
            decoding is greedy (``do_sample=False``), so no sampling occurs.
        top_p: Accepted for backward compatibility but IGNORED (see above).

    Returns:
        Only the newly generated text (the prompt is not echoed back),
        stripped of surrounding whitespace.
    """
    inputs = self.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1024,  # cap input to avoid OOM and slow processing
    ).to(self.model.device)

    with torch.no_grad():
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy — ~3x faster than sampling
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.1,  # avoid repetition loops
        )

    # Decode only the newly generated tokens (skip the prompt tokens).
    input_len = inputs["input_ids"].shape[1]
    generated_ids = outputs[0][input_len:]
    generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
    return generated_text.strip()
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
# Singleton instance
|
| 94 |
_llm_model = None
|