Mohammad Wasil committed on
Commit
eb597aa
·
1 Parent(s): 62a2bc4

updating frontend

Browse files
Files changed (2) hide show
  1. main.py +306 -86
  2. requirements.txt +1 -0
main.py CHANGED
@@ -1,106 +1,326 @@
1
- import uuid
2
- import json
3
- import asyncio
4
- import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import os
6
- import sys
 
 
7
  from contextlib import asynccontextmanager
8
- from loguru import logger
9
- from fastapi import FastAPI, HTTPException, status, Response
10
- from fastapi.middleware.cors import CORSMiddleware
11
- from fastapi.staticfiles import StaticFiles
12
- from fastapi.responses import HTMLResponse
13
-
14
- # Import your existing schemas (Ensure schemas.py is in the same folder)
15
- from schemas import ChatRequest, ChatResponse
16
-
17
- # -------------------------------------------------
18
- # 1. Loguru Configuration
19
- # -------------------------------------------------
20
- logger.remove()
21
- logger.add(sys.stdout, format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level}</level> | <cyan>{extra[session_id]}</cyan> - {message}")
22
- logger = logger.bind(session_id="SYSTEM")
23
-
24
- # -------------------------------------------------
25
- # 2. AI Logic (Replacing the MQTT Worker)
26
- # -------------------------------------------------
27
- # We define a direct function instead of publishing to MQTT
28
- async def get_ai_response(question: str):
29
- """
30
- Replace this with your actual agent logic (e.g., LangChain or Groq).
31
- This simulates what your 'worker' used to do.
32
- """
33
- # Simulate processing time
34
- await asyncio.sleep(1)
35
- return {
36
- "answer": f"I am your SmartCoffee assistant. You asked: {question}",
37
- "sources": ["knowledge_base_v1"],
38
- "timestamp": time.time()
39
- }
40
 
41
- # -------------------------------------------------
42
- # 3. App Lifespan
43
- # -------------------------------------------------
 
44
  @asynccontextmanager
45
  async def lifespan(app: FastAPI):
46
- logger.info("Starting AI Agent on Hugging Face...")
 
 
47
  yield
48
- logger.info("Shutting down...")
49
-
50
- # -------------------------------------------------
51
- # 4. App Init
52
- # -------------------------------------------------
53
- app = FastAPI(title="SmartCoffee AI 2026", lifespan=lifespan)
54
-
55
- # Allow CORS for local testing, though HF uses same-origin
56
- app.add_middleware(
57
- CORSMiddleware,
58
- allow_origins=["*"],
59
- allow_methods=["*"],
60
- allow_headers=["*"],
61
  )
62
 
63
- # --- CRITICAL: Mount Static Files ---
64
- # This serves your index.html, CSS, and JS
65
- app.mount("/static", StaticFiles(directory="static"), name="static")
66
 
67
- # -------------------------------------------------
68
- # 5. Routes
69
- # -------------------------------------------------
 
 
 
70
 
71
- @app.get("/", response_class=HTMLResponse)
72
- async def serve_frontend():
73
- """Serves the main chat interface"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  try:
75
- with open("static/index.html", "r", encoding="utf-8") as f:
76
- return HTMLResponse(content=f.read())
77
- except FileNotFoundError:
78
- return HTMLResponse(content="<h1>index.html not found in /static</h1>", status_code=404)
 
 
 
 
79
 
80
- @app.post("/api/v1/chat", response_model=ChatResponse)
81
- async def chat(request: ChatRequest):
82
- if request.session_id == "default":
83
- request.session_id = f"hf_{uuid.uuid4().hex[:12]}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- request_logger = logger.bind(session_id=request.session_id)
86
- request_logger.info(f"Processing request: {request.question}")
 
 
 
87
 
 
 
 
 
 
 
 
 
 
88
  try:
89
- # Instead of MQTT publish, call logic directly
90
- response = await get_ai_response(request.question)
91
-
92
- request_logger.success("Response generated.")
93
- return ChatResponse(
94
- question=request.question,
95
- answer=response["answer"],
96
- sources=response.get("sources", []),
97
- session_id=request.session_id,
98
- timestamp=response.get("timestamp", time.time()),
99
  )
 
 
 
 
 
 
 
 
 
 
 
 
100
  except Exception as e:
101
- request_logger.error(f"Error: {str(e)}")
102
- raise HTTPException(status_code=500, detail="Internal AI Error")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  @app.get("/health")
105
  async def health():
106
- return {"status": "healthy", "platform": "Hugging Face"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import uuid
2
+ # import json
3
+ # import asyncio
4
+ # import time
5
+ # import os
6
+ # import sys
7
+ # from contextlib import asynccontextmanager
8
+ # from loguru import logger
9
+ # from fastapi import FastAPI, HTTPException, status, Response
10
+ # from fastapi.middleware.cors import CORSMiddleware
11
+ # from fastapi.staticfiles import StaticFiles
12
+ # from fastapi.responses import HTMLResponse
13
+
14
+ # # Import your existing schemas (Ensure schemas.py is in the same folder)
15
+ # from schemas import ChatRequest, ChatResponse
16
+
17
+ # # -------------------------------------------------
18
+ # # 1. Loguru Configuration
19
+ # # -------------------------------------------------
20
+ # logger.remove()
21
+ # logger.add(sys.stdout, format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level}</level> | <cyan>{extra[session_id]}</cyan> - {message}")
22
+ # logger = logger.bind(session_id="SYSTEM")
23
+
24
+ # # -------------------------------------------------
25
+ # # 2. AI Logic (Replacing the MQTT Worker)
26
+ # # -------------------------------------------------
27
+ # # We define a direct function instead of publishing to MQTT
28
+ # async def get_ai_response(question: str):
29
+ # """
30
+ # Replace this with your actual agent logic (e.g., LangChain or Groq).
31
+ # This simulates what your 'worker' used to do.
32
+ # """
33
+ # # Simulate processing time
34
+ # await asyncio.sleep(1)
35
+ # return {
36
+ # "answer": f"I am your SmartCoffee assistant. You asked: {question}",
37
+ # "sources": ["knowledge_base_v1"],
38
+ # "timestamp": time.time()
39
+ # }
40
+
41
+ # # -------------------------------------------------
42
+ # # 3. App Lifespan
43
+ # # -------------------------------------------------
44
+ # @asynccontextmanager
45
+ # async def lifespan(app: FastAPI):
46
+ # logger.info("Starting AI Agent on Hugging Face...")
47
+ # yield
48
+ # logger.info("Shutting down...")
49
+
50
+ # # -------------------------------------------------
51
+ # # 4. App Init
52
+ # # -------------------------------------------------
53
+ # app = FastAPI(title="SmartCoffee AI 2026", lifespan=lifespan)
54
+
55
+ # # Allow CORS for local testing, though HF uses same-origin
56
+ # app.add_middleware(
57
+ # CORSMiddleware,
58
+ # allow_origins=["*"],
59
+ # allow_methods=["*"],
60
+ # allow_headers=["*"],
61
+ # )
62
+
63
+ # # --- CRITICAL: Mount Static Files ---
64
+ # # This serves your index.html, CSS, and JS
65
+ # app.mount("/static", StaticFiles(directory="static"), name="static")
66
+
67
+ # # -------------------------------------------------
68
+ # # 5. Routes
69
+ # # -------------------------------------------------
70
+
71
+ # @app.get("/", response_class=HTMLResponse)
72
+ # async def serve_frontend():
73
+ # """Serves the main chat interface"""
74
+ # try:
75
+ # with open("static/index.html", "r", encoding="utf-8") as f:
76
+ # return HTMLResponse(content=f.read())
77
+ # except FileNotFoundError:
78
+ # return HTMLResponse(content="<h1>index.html not found in /static</h1>", status_code=404)
79
+
80
+ # @app.post("/api/v1/chat", response_model=ChatResponse)
81
+ # async def chat(request: ChatRequest):
82
+ # if request.session_id == "default":
83
+ # request.session_id = f"hf_{uuid.uuid4().hex[:12]}"
84
+
85
+ # request_logger = logger.bind(session_id=request.session_id)
86
+ # request_logger.info(f"Processing request: {request.question}")
87
+
88
+ # try:
89
+ # # Instead of MQTT publish, call logic directly
90
+ # response = await get_ai_response(request.question)
91
+
92
+ # request_logger.success("Response generated.")
93
+ # return ChatResponse(
94
+ # question=request.question,
95
+ # answer=response["answer"],
96
+ # sources=response.get("sources", []),
97
+ # session_id=request.session_id,
98
+ # timestamp=response.get("timestamp", time.time()),
99
+ # )
100
+ # except Exception as e:
101
+ # request_logger.error(f"Error: {str(e)}")
102
+ # raise HTTPException(status_code=500, detail="Internal AI Error")
103
+
104
+ # @app.get("/health")
105
+ # async def health():
106
+ # return {"status": "healthy", "platform": "Hugging Face"}
107
+
108
+
109
+
110
+
111
+ from fastapi import FastAPI, Request, HTTPException
112
+ from fastapi.responses import HTMLResponse, RedirectResponse
113
+ from fastapi.staticfiles import StaticFiles
114
+ from pydantic import BaseModel, Field, field_validator, validator
115
  import os
116
+ import re
117
+ import time
118
+ import uuid
119
  from contextlib import asynccontextmanager
120
+ import logging
121
+
122
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Space-specific: Use mounted dataset path
# NOTE(review): KB_PATH is never read anywhere in this file — the knowledge
# base is loaded from a HF dataset in load_knowledge_base() instead.
# Confirm whether this constant is still needed.
KB_PATH = "/data/knowledge_base"

# Groq client setup
# NOTE(review): the client is constructed at import time; if GROQ_API_KEY is
# unset this silently creates a client with api_key=None and every LLM call
# fails later at request time rather than failing fast here.
from groq import Groq
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
# Space hardware: CPU-basic, limit memory
# Upper bound on the in-memory `sessions` dict; enforced by eviction in the
# /api/v1/chat handler.
MAX_SESSIONS = 50 # Lower for free tier
135
+
136
+ # Lifespan for startup/shutdown
137
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the knowledge base on startup.

    Also records the startup timestamp on ``app.state`` — the
    /api/v1/metrics endpoint reads ``app.state.startup_time`` and would
    otherwise raise AttributeError, since nothing else in the file sets it.
    """
    logger.info("🚀 Starting up agent...")
    # Fix: required by /api/v1/metrics (and the status endpoint) for uptime.
    app.state.startup_time = time.time()
    # Load knowledge base here (populates the module-global `knowledge_docs`)
    await load_knowledge_base()
    yield
    logger.info("🔌 Shutting down agent...")
144
+
145
app = FastAPI(
    title="SmartCoffee AI Agent",
    description="AI Support Agent - Hugging Face Spaces Edition",
    version="1.0.0",
    lifespan=lifespan
)

# Mount static files (CSS/JS)
# NOTE(review): directory="." serves EVERY file in the working directory
# under /static — including main.py, requirements.txt, and any local config
# or secrets. Strongly consider a dedicated "static/" directory (as the
# previous revision of this file used).
app.mount("/static", StaticFiles(directory="."), name="static")
 
154
 
155
# Pydantic models
class ChatRequest(BaseModel):
    """Incoming chat payload.

    Fix: the original declared `question` twice — a constrained
    `Field(..., min_length=3, max_length=300)` followed by a bare
    `question: str` annotation. In pydantic the later annotation overrides
    the earlier one, silently dropping the length validation. Keep a single
    constrained declaration.
    """
    question: str = Field(..., min_length=3, max_length=300)
    session_id: str = Field(default="default", pattern=r"^[a-zA-Z0-9_-]+$")

    @field_validator('question')
    @classmethod
    def sanitize_input(cls, v: str) -> str:
        # Standardize whitespace and strip
        v = re.sub(r'\s+', ' ', v).strip()

        # Security check for prompt injection keywords
        # NOTE(review): substring matching also rejects benign words such as
        # "administrator" or "ecosystem" — confirm this is intended.
        forbidden_keywords = ['ignore', 'system', 'admin', 'prompt']
        if any(word in v.lower() for word in forbidden_keywords):
            raise ValueError("Invalid input pattern")

        return v
174
+
175
# In-memory session store (no Redis in free tier)
# Maps session_id -> {"history": [...], "created_at": float}; bounded by
# MAX_SESSIONS via oldest-first eviction in the /api/v1/chat handler.
# Lost on restart — acceptable for a free-tier demo.
sessions = {}
177
+
178
async def load_knowledge_base():
    """Load knowledge base documents from a HF dataset at startup.

    Populates the module-global `knowledge_docs` list consumed by
    rag_query(). On any failure the global is left as an empty list so the
    app still serves requests (rag_query then reports the KB as missing).
    """
    from datasets import load_dataset

    # Declare and initialize the global up front so `knowledge_docs` exists
    # even if load_dataset raises before the assignment inside `try`
    # (the original only declared `global` mid-try).
    global knowledge_docs
    knowledge_docs = []

    logger.info("📚 Loading knowledge base...")
    try:
        # NOTE(review): "YOUR_USERNAME/smartcoffee-kb" is a placeholder —
        # this call will fail until replaced with the real dataset repo id.
        dataset = load_dataset("YOUR_USERNAME/smartcoffee-kb", split="train")
        # Process into text chunks
        knowledge_docs = [doc["text"] for doc in dataset]
        logger.info(f"✅ Loaded {len(knowledge_docs)} documents")
    except Exception as e:
        logger.error(f"❌ Failed to load KB: {e}")
        knowledge_docs = []
192
 
193
# RAG function
def rag_query(question: str) -> str:
    """Retrieve context for `question` from the in-memory knowledge base.

    Uses a lightweight TF-IDF + cosine-similarity search (memory-efficient
    for the CPU-basic Space tier) and returns the top-2 matching documents
    joined into a single context string, best match first.

    Fixes: drops the unused `langchain_huggingface.HuggingFaceEmbeddings`
    import, which pulled in the heavy transformers stack on every request
    without ever being used.

    Note: the vectorizer is re-fit on every call; acceptable for a small KB,
    but worth caching if the document set grows.
    """
    from sklearn.metrics.pairwise import cosine_similarity
    from sklearn.feature_extraction.text import TfidfVectorizer
    import numpy as np

    if not knowledge_docs:
        return "Knowledge base not loaded."

    # Simple TF-IDF search (memory-efficient)
    vectorizer = TfidfVectorizer(max_features=1000, stop_words='english')
    doc_vectors = vectorizer.fit_transform(knowledge_docs)
    question_vec = vectorizer.transform([question])

    # Get top 2 most similar docs; argsort is ascending, so reverse the
    # tail slice to put the best match first in the context.
    similarities = cosine_similarity(question_vec, doc_vectors).flatten()
    top_indices = np.argsort(similarities)[-2:][::-1]

    context = "\n\n".join([knowledge_docs[i] for i in top_indices])
    return context
215
 
216
# LLM call
def generate_response(question: str, context: str, session_id: str) -> dict:
    """Ask the Groq LLM a context-grounded question.

    Returns a dict with the answer plus latency/token metadata on success,
    or a fallback answer with the error string on failure. `session_id` is
    accepted for interface parity with the caller but is not used here.
    """
    start_time = time.time()

    prompt = f"""You are SmartCoffee Support AI. Use ONLY this context:

Context:
{context}

Question: {question}

Answer concisely in 2-3 sentences. If unsure, say "I need to check with my team."

Answer:"""

    try:
        completion = client.chat.completions.create(
            model="llama3-8b-8192",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=200,
            temperature=0.1
        )

        elapsed = time.time() - start_time
        usage = completion.usage

        return {
            "answer": completion.choices[0].message.content,
            "latency": elapsed,
            "tokens_in": usage.prompt_tokens,
            "tokens_out": usage.completion_tokens,
            "model": "groq-llama3-8b",
            "sources": [f"doc_{i}" for i in range(2)]
        }

    except Exception as e:
        logger.error(f"LLM error: {e}")
        return {
            "answer": "Sorry, I'm having trouble processing your request.",
            "latency": time.time() - start_time,
            "error": str(e)
        }
257
+
258
# Routes
@app.get("/", response_class=HTMLResponse)
async def serve_frontend():
    """Serve the combined frontend (index.html from the app directory)."""
    try:
        with open("index.html", "r", encoding="utf-8") as f:
            return HTMLResponse(content=f.read())
    except FileNotFoundError:
        # Restores the graceful 404 the previous revision had; without it a
        # missing index.html surfaces as an unhandled 500.
        return HTMLResponse(content="<h1>index.html not found</h1>", status_code=404)
264
+
265
@app.post("/api/v1/chat")
async def chat(request: ChatRequest):
    """Handle one chat turn: retrieve context, call the LLM, track history.

    Returns question/answer/sources/session_id/latency_ms. Raises 400 on a
    ValueError from downstream logic and 500 on anything unexpected.
    """
    try:
        # Get session memory (created lazily; persisted at the end of the turn)
        session = sessions.get(request.session_id, {
            "history": [],
            "created_at": time.time()
        })

        # Clean up old sessions — free tier, keep memory bounded
        if len(sessions) > MAX_SESSIONS:
            oldest = min(sessions, key=lambda k: sessions[k]["created_at"])
            del sessions[oldest]

        # Add user message to history
        session["history"].append({"role": "user", "content": request.question})

        # RAG query
        context = rag_query(request.question)

        # Generate response
        result = generate_response(request.question, context, request.session_id)

        # Add bot message to history, then cap the history length — the
        # original never trimmed it, so a long-lived session grew without
        # bound (this file explicitly targets the low-memory CPU tier).
        session["history"].append({"role": "bot", "content": result["answer"]})
        session["history"] = session["history"][-40:]
        sessions[request.session_id] = session

        return {
            "question": request.question,
            "answer": result["answer"],
            "sources": result.get("sources", []),
            "session_id": request.session_id,
            "latency_ms": int(result["latency"] * 1000)
        }

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail="Failed to process request")
305
 
306
@app.get("/health")
async def health():
    """Liveness probe: report service status, session count, and KB size."""
    # `knowledge_docs` only exists after load_knowledge_base() has run.
    kb_size = len(knowledge_docs) if 'knowledge_docs' in globals() else 0
    return {
        "status": "operational",
        "sessions_active": len(sessions),
        "kb_loaded": kb_size
    }
313
+
314
@app.get("/api/v1/metrics")
async def metrics():
    """Simple metrics endpoint: request totals, sessions, uptime."""
    # Fix: the original read app.state.startup_time unconditionally, but
    # nothing in this file ever assigns it, so every call raised
    # AttributeError. Fall back to 0 ("just started") when unset.
    started = getattr(app.state, "startup_time", None)
    uptime = int(time.time() - started) if started is not None else 0
    return {
        # Each completed turn stores one user + one bot message, hence // 2
        "total_requests": sum(len(s.get("history", [])) for s in sessions.values()) // 2,
        "active_sessions": len(sessions),
        "uptime_seconds": uptime
    }
322
+
323
+
324
@app.get("/api/v1/status")
async def root():
    """JSON status endpoint.

    Fix: the original registered this at "/", duplicating serve_frontend's
    route; FastAPI dispatches to the first registration, so this handler
    was unreachable dead code. Moved to a unique path. Also guards
    `app.state.startup_time`, which is never assigned in this file and
    would raise AttributeError.
    """
    started = getattr(app.state, "startup_time", None)
    uptime = time.time() - started if started is not None else 0.0
    return {"message": "Agent is running", "uptime": uptime}
requirements.txt CHANGED
@@ -30,3 +30,4 @@ prometheus-client==0.23.1
30
 
31
  #loguru
32
  loguru==0.7.3
 
 
30
 
31
  #loguru
32
  loguru==0.7.3
33
+ datasets