Jay-10020 committed on
Commit
4e4501d
·
1 Parent(s): c766e4c

RAG test1

Browse files
Files changed (3) hide show
  1. api/main.py +49 -0
  2. hybrid/assistant.py +15 -34
  3. models/llm.py +18 -26
api/main.py CHANGED
@@ -85,6 +85,17 @@ class HybridQueryRequest(BaseModel):
85
  use_web_fallback: bool = True
86
 
87
 
 
 
 
 
 
 
 
 
 
 
 
88
  # NEW: Speech-to-Text Models
89
  class TranscribeRequest(BaseModel):
90
  audio_filename: str
@@ -512,6 +523,44 @@ async def hybrid_query(request: HybridQueryRequest):
512
  raise HTTPException(status_code=500, detail=str(e))
513
 
514
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
  # ============================================================================
516
  # VOICE-TO-TEXT ENDPOINTS (NEW)
517
  # ============================================================================
 
85
  use_web_fallback: bool = True
86
 
87
 
88
# Fast endpoints for Node-side orchestration
class EmbedRequest(BaseModel):
    # Request body for POST /embed: one raw text string to vectorize
    # with the sentence-transformer (no LLM involved).
    text: str
91
+
92
+
93
class GenerateRequest(BaseModel):
    # Request body for POST /generate: the Node backend has already done
    # retrieval and sends the assembled context alongside the user query.
    query: str
    context: str
    source_type: str = "documents" # "documents" | "web"
99
  # NEW: Speech-to-Text Models
100
  class TranscribeRequest(BaseModel):
101
  audio_filename: str
 
523
  raise HTTPException(status_code=500, detail=str(e))
524
 
525
 
526
+ # ============================================================================
527
+ # FAST PRIMITIVE ENDPOINTS (used by Node backend for server-side RAG)
528
+ # ============================================================================
529
+
530
@app.post("/embed")
async def embed_text(request: EmbedRequest):
    """Return the sentence-transformer embedding for a single text string.

    Fast path: only the embedding model is used — the LLM is never loaded
    or invoked here.
    """
    try:
        # Imported lazily so the endpoint does not force model loading
        # at module import time.
        from models.embeddings import get_embedding_model

        vec = get_embedding_model().encode_query(request.text)
        return {"embedding": vec.tolist(), "dimension": len(vec)}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
543
+
544
+
545
@app.post("/generate")
async def generate_answer(request: GenerateRequest):
    """Generate a short answer from pre-built context.

    Called by the Node backend after it has already performed retrieval
    (MongoDB), so no retrieval happens here — much faster than /assistant.
    """
    try:
        # NOTE(review): this reaches into a private method of the assistant
        # (_generate_answer); consider exposing a public
        # generate_from_context() on HybridAssistant instead.
        assistant = get_hybrid_assistant_instance()
        answer = assistant._generate_answer(
            query=request.query,
            context=request.context,
            source_type=request.source_type,
        )
        return {"answer": answer}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
562
+
563
+
564
  # ============================================================================
565
  # VOICE-TO-TEXT ENDPOINTS (NEW)
566
  # ============================================================================
hybrid/assistant.py CHANGED
@@ -126,46 +126,27 @@ class HybridAssistant:
126
  context: str,
127
  source_type: str
128
  ) -> str:
129
- """Generate answer from context"""
130
-
131
- if source_type == "documents":
132
- prompt = f"""You are a helpful AI assistant. Answer the question using ONLY the information from the provided context.
133
-
134
- Context from uploaded documents:
135
- {context}
136
-
137
- Question: {query}
138
-
139
- Instructions:
140
- - Answer based on the context above
141
- - Cite sources using [Source 1], [Source 2], etc.
142
- - If the context doesn't fully answer the question, say so
143
- - Be concise and accurate
144
-
145
- Answer:"""
146
-
147
- else: # web sources
148
- prompt = f"""You are a helpful AI assistant. Answer the question using the information from web search results.
149
 
150
- Web search results:
151
- {context}
152
-
153
- Question: {query}
 
 
 
154
 
155
- Instructions:
156
- - Synthesize information from the web sources
157
- - Cite sources using [Web Source 1], [Web Source 2], etc.
158
- - Provide accurate and helpful information
159
- - Be concise
160
 
161
- Answer:"""
162
-
163
  response = self.llm.generate(
164
  prompt=prompt,
165
- max_new_tokens=512,
166
- temperature=0.7
167
  )
168
-
169
  return response.strip()
170
 
171
  # Singleton
 
126
  context: str,
127
  source_type: str
128
  ) -> str:
129
+ """Generate answer from context using TinyLlama chat format for speed."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ # TinyLlama chat template — keeps token count minimal for fast inference
132
+ if source_type == "documents":
133
+ system = "You are a helpful study assistant. Answer ONLY using the provided context. Cite [Source N] inline."
134
+ user_content = f"Context:\n{context[:1500]}\n\nQuestion: {query}"
135
+ else:
136
+ system = "You are a helpful assistant. Summarise the web results to answer the question concisely. Cite [Web N] inline."
137
+ user_content = f"Web results:\n{context[:1500]}\n\nQuestion: {query}"
138
 
139
+ prompt = (
140
+ f"<|system|>\n{system}</s>\n"
141
+ f"<|user|>\n{user_content}</s>\n"
142
+ f"<|assistant|>\n"
143
+ )
144
 
 
 
145
  response = self.llm.generate(
146
  prompt=prompt,
147
+ max_new_tokens=150,
 
148
  )
149
+
150
  return response.strip()
151
 
152
  # Singleton
models/llm.py CHANGED
@@ -60,43 +60,35 @@ class LanguageModel:
60
  def generate(
61
  self,
62
  prompt: str,
63
- max_new_tokens: int = MAX_NEW_TOKENS,
64
  temperature: float = TEMPERATURE,
65
  top_p: float = TOP_P
66
  ) -> str:
67
  """
68
- Generate text from prompt
69
-
70
- Args:
71
- prompt: Input prompt
72
- max_new_tokens: Maximum tokens to generate
73
- temperature: Sampling temperature
74
- top_p: Top-p sampling
75
-
76
- Returns:
77
- Generated text
78
  """
79
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
80
-
 
 
 
 
 
81
  with torch.no_grad():
82
  outputs = self.model.generate(
83
  **inputs,
84
  max_new_tokens=max_new_tokens,
85
- temperature=temperature,
86
- top_p=top_p,
87
- do_sample=True,
88
  pad_token_id=self.tokenizer.pad_token_id,
89
- eos_token_id=self.tokenizer.eos_token_id
 
90
  )
91
-
92
- # Decode and remove input prompt
93
- generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
94
-
95
- # Remove the input prompt from output
96
- if generated_text.startswith(prompt):
97
- generated_text = generated_text[len(prompt):].strip()
98
-
99
- return generated_text
100
 
101
  # Singleton instance
102
  _llm_model = None
 
60
  def generate(
61
  self,
62
  prompt: str,
63
+ max_new_tokens: int = 150,
64
  temperature: float = TEMPERATURE,
65
  top_p: float = TOP_P
66
  ) -> str:
67
  """
68
+ Generate text from prompt using greedy decoding for speed.
 
 
 
 
 
 
 
 
 
69
  """
70
+ inputs = self.tokenizer(
71
+ prompt,
72
+ return_tensors="pt",
73
+ truncation=True,
74
+ max_length=1024, # cap input to avoid OOM and slow processing
75
+ ).to(self.model.device)
76
+
77
  with torch.no_grad():
78
  outputs = self.model.generate(
79
  **inputs,
80
  max_new_tokens=max_new_tokens,
81
+ do_sample=False, # greedy — ~3x faster than sampling
 
 
82
  pad_token_id=self.tokenizer.pad_token_id,
83
+ eos_token_id=self.tokenizer.eos_token_id,
84
+ repetition_penalty=1.1, # avoid repetition loops
85
  )
86
+
87
+ # Decode only the newly generated tokens (skip input)
88
+ input_len = inputs["input_ids"].shape[1]
89
+ generated_ids = outputs[0][input_len:]
90
+ generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
91
+ return generated_text.strip()
 
 
 
92
 
93
  # Singleton instance
94
  _llm_model = None