yukee1992 committed · verified
Commit cdf7c7c · Parent: e84a89e

Create app.py

Files changed (1): app.py (+361, -0)
app.py ADDED
@@ -0,0 +1,361 @@
# app.py - General API endpoint
from fastapi import FastAPI, HTTPException, Header
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
import logging
from transformers import pipeline
import uvicorn
import os
from datetime import datetime

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="AI Chat API for n8n",
    description="General AI processing API that accepts prompts from n8n workflows",
    version="1.0.0"
)

# Request/Response models
class PromptRequest(BaseModel):
    """Request model for prompt processing"""
    prompt: str  # User's instruction/query
    content: Optional[str] = None  # Optional content to process
    parameters: Optional[Dict[str, Any]] = None  # Optional parameters
    task_type: Optional[str] = None  # Optional: summarize, generate, classify, etc.
    max_length: Optional[int] = 200
    temperature: Optional[float] = 0.7
    return_type: Optional[str] = "text"  # text, json, list

class PromptResponse(BaseModel):
    """Response model"""
    success: bool
    result: Optional[Any] = None
    error: Optional[str] = None
    processing_time: Optional[float] = None
    model_used: Optional[str] = None

class BatchRequest(BaseModel):
    """Batch processing request"""
    prompts: List[PromptRequest]
    parallel: Optional[bool] = False

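# For illustration only: a hypothetical JSON body matching PromptRequest.
# Field names come from the model above; the values are invented.
#
#   {
#       "prompt": "Summarize this article",
#       "content": "<long article text>",
#       "max_length": 120,
#       "temperature": 0.7,
#       "return_type": "text"
#   }
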
# Initialize models
class AIModelManager:
    """Manages AI models dynamically"""
    def __init__(self):
        self.models = {}
        self.load_models()

    def load_models(self):
        """Load essential models"""
        try:
            # Load a general text generation model
            self.models["text-generation"] = pipeline(
                "text-generation",
                model="gpt2",
                max_length=200,
                device=-1  # CPU by default
            )

            # Load summarization model
            self.models["summarization"] = pipeline(
                "summarization",
                model="facebook/bart-large-cnn",
                device=-1
            )

            # Load text classification for intent detection
            self.models["text-classification"] = pipeline(
                "text-classification",
                model="distilbert-base-uncased-finetuned-sst-2-english",
                device=-1
            )

            logger.info("Models loaded successfully")

        except Exception as e:
            logger.error(f"Error loading models: {e}")
            # Leave the registry empty; the processing methods below
            # fall back to simple heuristics when a model is missing
            self.models = {}

    def process_prompt(self, prompt: str, content: Optional[str] = None, **kwargs) -> str:
        """
        General prompt processing method
        Args:
            prompt: Instruction/query from user
            content: Optional content to process
            **kwargs: Additional parameters
        """
        try:
            # Combine prompt and content
            full_input = prompt
            if content:
                full_input = f"{prompt}\n\nContent: {content}"

            # Determine task type from prompt
            task_type = self._detect_task_type(prompt, content)

            # Process based on task type
            if task_type == "summarize" and content:
                return self._process_summarization(content, **kwargs)

            elif task_type == "generate":
                return self._process_generation(full_input, **kwargs)

            elif task_type == "classify" and content:
                return self._process_classification(content, **kwargs)

            else:
                # Default: general text generation
                return self._process_generation(full_input, **kwargs)

        except Exception as e:
            logger.error(f"Error processing prompt: {e}")
            return f"Error processing your request: {str(e)}"

    def _detect_task_type(self, prompt: str, content: Optional[str] = None) -> str:
        """Detect task type via keyword matching on the prompt"""
        prompt_lower = prompt.lower()

        task_keywords = {
            "summarize": ["summarize", "summary", "brief", "overview"],
            "generate": ["generate", "create", "write", "make", "draft"],
            "classify": ["classify", "categorize", "label", "tag"],
            "translate": ["translate", "convert language"],
            "analyze": ["analyze", "analyse", "evaluate", "assess"]
        }

        for task, keywords in task_keywords.items():
            if any(keyword in prompt_lower for keyword in keywords):
                return task

        return "general"

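    # Illustrative behavior of the keyword table above: the prompt
    # "Write a short summary of this report" matches both "write" (generate)
    # and "summary" (summarize), but dicts preserve insertion order, so
    # "summarize" is checked first and wins. Prompts with no keyword hit
    # fall through to "general".
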
    def _process_summarization(self, content: str, **kwargs) -> str:
        """Process summarization task"""
        if "summarization" in self.models:
            max_length = kwargs.get("max_length", 150)
            min_length = kwargs.get("min_length", 30)

            result = self.models["summarization"](
                content,
                max_length=max_length,
                min_length=min_length,
                do_sample=False,
                truncation=True  # BART cannot accept inputs beyond its context window
            )
            return result[0]['summary_text']
        else:
            # Fallback: return the first couple of sentences
            sentences = content.split('. ')
            if len(sentences) > 3:
                return '. '.join(sentences[:2]) + '.'
            return content[:100] + "..."

    def _process_generation(self, prompt: str, **kwargs) -> str:
        """Process text generation task"""
        if "text-generation" in self.models:
            max_length = kwargs.get("max_length", 100)
            temperature = kwargs.get("temperature", 0.7)

            result = self.models["text-generation"](
                prompt,
                max_length=max_length,
                temperature=temperature,
                do_sample=True,  # temperature only takes effect when sampling
                num_return_sequences=1
            )
            return result[0]['generated_text']
        else:
            # Fallback response when no model is loaded
            return f"Processed: {prompt[:50]}... [Model not loaded]"

    def _process_classification(self, content: str, **kwargs) -> str:
        """Process classification task"""
        if "text-classification" in self.models:
            result = self.models["text-classification"](content)
            return str(result)
        else:
            return "Classification model not available"

# Initialize model manager
model_manager = AIModelManager()

# API Endpoints
@app.get("/")
async def root():
    """Root endpoint"""
    return {
        "status": "online",
        "service": "AI Chat API for n8n",
        "endpoints": {
            "/process": "Process single prompt (POST)",
            "/batch": "Process multiple prompts (POST)",
            "/health": "Health check (GET)",
            "/models": "List loaded models (GET)"
        }
    }

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "models_loaded": len(model_manager.models) > 0
    }

@app.get("/models")
async def list_models():
    """List loaded models"""
    return {
        "models": list(model_manager.models.keys()),
        "count": len(model_manager.models)
    }

@app.post("/process", response_model=PromptResponse)
async def process_prompt(request: PromptRequest):
    """
    Main endpoint for processing prompts from n8n
    """
    start_time = datetime.now()

    try:
        logger.info(f"Processing prompt: {request.prompt[:50]}...")

        # Process the prompt
        result = model_manager.process_prompt(
            prompt=request.prompt,
            content=request.content,
            max_length=request.max_length,
            temperature=request.temperature
        )

        processing_time = (datetime.now() - start_time).total_seconds()

        return PromptResponse(
            success=True,
            result=result,
            processing_time=processing_time,
            model_used="text-generation"  # You can make this dynamic per task type
        )

    except Exception as e:
        logger.error(f"Error in process_prompt: {e}")
        return PromptResponse(
            success=False,
            error=str(e),
            processing_time=(datetime.now() - start_time).total_seconds()
        )

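# Illustrative call to /process (assuming the server runs locally on port 8000):
#
#   curl -X POST http://localhost:8000/process \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Summarize this", "content": "FastAPI is a web framework..."}'
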
@app.post("/batch", response_model=List[PromptResponse])
async def process_batch(request: BatchRequest):
    """
    Process multiple prompts in batch (sequentially; the `parallel`
    flag is accepted but not yet implemented)
    """
    responses = []

    for prompt_req in request.prompts:
        start_time = datetime.now()

        try:
            result = model_manager.process_prompt(
                prompt=prompt_req.prompt,
                content=prompt_req.content,
                max_length=prompt_req.max_length,
                temperature=prompt_req.temperature
            )

            responses.append(PromptResponse(
                success=True,
                result=result,
                processing_time=(datetime.now() - start_time).total_seconds()
            ))

        except Exception as e:
            responses.append(PromptResponse(
                success=False,
                error=str(e),
                processing_time=(datetime.now() - start_time).total_seconds()
            ))

    return responses

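# Illustrative batch payload (values invented for the example):
#
#   {"prompts": [{"prompt": "Write a product tagline"},
#                {"prompt": "Summarize", "content": "<text>"}]}
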
# Webhook endpoint (for n8n webhook node)
@app.post("/webhook")
async def webhook_endpoint(
    payload: Dict[str, Any],
    x_n8n_signature: Optional[str] = Header(None)  # accepted but not verified here
):
    """
    Webhook endpoint specifically for n8n
    """
    logger.info(f"Webhook received from n8n: {list(payload.keys())}")

    # Extract prompt from n8n payload
    prompt = payload.get("prompt") or payload.get("text") or payload.get("message")
    content = payload.get("content") or payload.get("data")

    if not prompt:
        raise HTTPException(status_code=400, detail="No prompt provided in payload")

    # Process the prompt
    result = model_manager.process_prompt(prompt, content)

    # Return in n8n-friendly format
    return {
        "success": True,
        "response": result,
        "timestamp": datetime.now().isoformat(),
        "webhook_id": payload.get("webhookId"),
        "workflow_id": payload.get("workflowId")
    }

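# One hypothetical payload an n8n node might send (only the key names used
# in the lookups above are meaningful; the values are made up):
#
#   {"prompt": "Classify this message", "content": "I love it!",
#    "webhookId": "abc123", "workflowId": "42"}
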
# Async task endpoint
@app.post("/async")
async def create_async_task(request: PromptRequest):
    """
    Create an async task (returns task ID immediately)
    """
    task_id = f"task_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    # In production, you'd queue this task with a task queue or background worker
    return {
        "task_id": task_id,
        "status": "queued",
        "message": "Task created successfully"
    }

@app.get("/task/{task_id}")
async def get_task_status(task_id: str):
    """
    Check status of async task
    """
    return {
        "task_id": task_id,
        "status": "completed",  # Mock response
        "result": "This is a mock result for async task"
    }

# For Hugging Face Spaces
@app.get("/hf_space")
async def hf_space_endpoint(prompt: Optional[str] = None, content: Optional[str] = None):
    """
    Simple endpoint for Hugging Face Spaces demo
    """
    if not prompt:
        return {"error": "Please provide a prompt parameter"}

    result = model_manager.process_prompt(prompt, content)

    return {
        "prompt": prompt,
        "response": result,
        "content_length": len(content) if content else 0
    }

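# To run locally (assuming fastapi, uvicorn, and transformers are installed):
#
#   python app.py
#   # or equivalently: uvicorn app:app --host 0.0.0.0 --port 8000
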
if __name__ == "__main__":
    port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)