yukee1992 committed on
Commit
a434ebb
·
verified ·
1 Parent(s): 66dcfeb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +348 -412
app.py CHANGED
@@ -1,461 +1,397 @@
1
- # app.py - UPDATED WITH NEW HUGGINGFACE API
2
- from fastapi import FastAPI, HTTPException
3
- from pydantic import BaseModel
4
- from typing import Optional, List, Dict, Any
5
  import os
6
- import json
7
- import requests
 
8
  import logging
9
- from datetime import datetime
10
- import time
 
 
 
 
 
 
 
11
 
12
- # Configure logging
13
- logging.basicConfig(level=logging.INFO)
14
- logger = logging.getLogger(__name__)
 
 
 
 
15
 
16
- app = FastAPI(
17
- title="AI Summarization API",
18
- description="Free AI for Summarization and Viral Stories",
19
- version="2.0.0"
20
  )
 
21
 
22
- # Get HuggingFace token from environment
23
- HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN", "")
24
- # NEW HUGGINGFACE API ENDPOINT
25
- HF_API_URL = "https://router.huggingface.co/huggingface"
26
-
27
- # WORKING MODELS with NEW API format
28
- MODELS = {
29
- "qwen": "Qwen/Qwen2.5-7B-Instruct", # Best for Chinese-English
30
- "mistral": "mistralai/Mistral-7B-Instruct-v0.3",
31
- "llama": "meta-llama/Llama-3.2-3B-Instruct",
32
- "phi": "microsoft/phi-2", # Lightweight
33
- "gemma": "google/gemma-2-9b-it", # Usually available
34
- "zephyr": "HuggingFaceH4/zephyr-7b-beta"
35
- }
36
-
37
- class SummarizeRequest(BaseModel):
38
- content: str
39
- language: Optional[str] = "chinese"
40
- max_length: Optional[int] = 150
41
- min_length: Optional[int] = 50
42
- model: Optional[str] = "qwen"
43
-
44
- class StoryRequest(BaseModel):
45
- topic: str
46
- platform: Optional[str] = "wechat"
47
- language: Optional[str] = "chinese"
48
- model: Optional[str] = "qwen"
49
 
50
- class ChatRequest(BaseModel):
51
- prompt: str
52
- model: Optional[str] = "qwen"
53
- max_tokens: Optional[int] = 500
54
 
55
- def call_huggingface_api(model: str, prompt: str, max_tokens: int = 500) -> str:
56
- """Call HuggingFace NEW Router API"""
57
-
58
- if not HF_TOKEN:
59
- raise Exception("HUGGINGFACE_TOKEN not configured. Please set it in environment variables.")
60
-
61
- model_name = MODELS.get(model, MODELS["qwen"])
62
-
63
- headers = {
64
- "Authorization": f"Bearer {HF_TOKEN}",
65
- "Content-Type": "application/json"
66
- }
67
-
68
- # Format prompt based on model
69
- if "qwen" in model_name.lower():
70
- formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
71
- elif "mistral" in model_name.lower():
72
- formatted_prompt = f"<s>[INST] {prompt} [/INST]"
73
- elif "llama" in model_name.lower():
74
- formatted_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
75
- else:
76
- formatted_prompt = prompt
77
-
78
- payload = {
79
- "model": model_name,
80
- "inputs": formatted_prompt,
81
- "parameters": {
82
- "max_new_tokens": max_tokens,
83
- "temperature": 0.7,
84
- "do_sample": True,
85
- "return_full_text": False
86
- }
87
- }
88
-
89
- try:
90
- logger.info(f"📤 Calling HuggingFace Router API: {model_name}")
91
-
92
- response = requests.post(
93
- f"{HF_API_URL}/models/v1/{model_name}",
94
- headers=headers,
95
- json=payload,
96
- timeout=60 # Increased timeout
97
- )
98
 
99
- if response.status_code == 200:
100
- result = response.json()
101
-
102
- # Parse response based on format
103
- if isinstance(result, list):
104
- if len(result) > 0:
105
- if isinstance(result[0], dict):
106
- if "generated_text" in result[0]:
107
- text = result[0]["generated_text"]
108
- else:
109
- text = str(result[0])
110
- else:
111
- text = str(result[0])
112
- else:
113
- text = "No response generated"
114
- elif isinstance(result, dict):
115
- if "generated_text" in result:
116
- text = result["generated_text"]
117
- elif "choices" in result: # Chat format
118
- if len(result["choices"]) > 0:
119
- text = result["choices"][0].get("message", {}).get("content", "")
120
- else:
121
- text = "No choices available"
122
- else:
123
- text = str(result)
124
- else:
125
- text = str(result)
126
 
127
- # Clean up the response
128
- if formatted_prompt in text:
129
- text = text.replace(formatted_prompt, "").strip()
 
130
 
131
- return text
 
 
 
 
 
132
 
133
- elif response.status_code == 503:
134
- # Model is loading
135
- error_data = response.json()
136
- error_msg = error_data.get("error", "Model is loading")
137
 
138
- # Check if it's loading or unavailable
139
- if "loading" in error_msg.lower():
140
- logger.info("⏳ Model is loading, waiting 45 seconds...")
141
- time.sleep(45)
142
-
143
- # Try one more time with longer timeout
144
- response = requests.post(
145
- f"{HF_API_URL}/models/v1/{model_name}",
146
- headers=headers,
147
- json=payload,
148
- timeout=90
149
- )
150
-
151
- if response.status_code == 200:
152
- return call_huggingface_api(model, prompt, max_tokens)
153
- else:
154
- raise Exception(f"Model still loading after wait: {error_msg}")
155
- else:
156
- raise Exception(f"Model unavailable: {error_msg}")
157
-
158
- elif response.status_code == 429:
159
- # Rate limit
160
- raise Exception("Rate limit exceeded. Please wait a moment and try again.")
161
-
162
- else:
163
- error_msg = response.text[:500]
164
- logger.error(f"API Error {response.status_code}: {error_msg}")
165
 
166
- # Try with a simpler model
167
- if model != "phi":
168
- logger.info(f"🔄 Trying with phi model instead...")
169
- return call_huggingface_api("phi", prompt, max_tokens)
170
- else:
171
- raise Exception(f"API Error {response.status_code}: {error_msg}")
172
-
173
- except requests.exceptions.Timeout:
174
- # Try with a smaller model
175
- if model != "phi":
176
- logger.warning("⏰ Timeout, trying phi model...")
177
- return call_huggingface_api("phi", prompt, max_tokens)
178
- else:
179
- raise Exception("Request timeout. Please try again with shorter text.")
180
- except Exception as e:
181
- raise Exception(f"API call failed: {str(e)}")
182
 
183
- @app.get("/")
184
- async def root():
185
- return {
186
- "status": "online",
187
- "service": "AI Summarization API (v2.0)",
188
- "models": list(MODELS.keys()),
189
- "recommended_model": "qwen (for Chinese)",
190
- "endpoints": {
191
- "/health": "GET - Health check",
192
- "/test": "GET - Test API",
193
- "/test/{model}": "GET - Test specific model",
194
- "/summarize": "POST - Summarize text",
195
- "/create_story": "POST - Create viral story",
196
- "/chat": "POST - General chat"
197
- },
198
- "note": "Using HuggingFace Router API"
199
- }
200
 
201
- @app.get("/health")
202
- async def health():
203
- hf_configured = bool(HF_TOKEN)
204
- return {
205
- "status": "healthy",
206
- "huggingface_configured": hf_configured,
207
- "available_models": len(MODELS),
208
- "timestamp": datetime.now().isoformat(),
209
- "api_version": "router.huggingface.co"
210
- }
211
 
212
- @app.get("/test")
213
- async def test():
214
- """Test with default model (qwen)"""
215
- return await test_model("qwen")
 
216
 
217
- @app.get("/test/{model_name}")
218
- async def test_model(model_name: str):
219
- """Test specific model"""
220
-
221
- if not HF_TOKEN:
222
- return {
223
- "success": False,
224
- "error": "HUGGINGFACE_TOKEN not configured",
225
- "help": "Get free token from https://huggingface.co/settings/tokens"
226
- }
227
-
228
- if model_name not in MODELS:
229
- return {
230
- "success": False,
231
- "error": f"Model '{model_name}' not available",
232
- "available_models": list(MODELS.keys())
233
- }
234
 
235
- # Simple test prompt
236
- test_prompt = "请用中文简单介绍人工智能,不超过100字。"
 
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  try:
239
- start_time = time.time()
240
- response = call_huggingface_api(model_name, test_prompt, 100)
241
- processing_time = time.time() - start_time
242
-
243
- return {
244
- "success": True,
245
- "model": model_name,
246
- "model_full_name": MODELS[model_name],
247
- "response": response,
248
- "response_preview": response[:200] + "..." if len(response) > 200 else response,
249
- "length": len(response),
250
- "processing_time_seconds": round(processing_time, 2)
251
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  except Exception as e:
253
- return {
254
- "success": False,
255
- "model": model_name,
256
- "error": str(e),
257
- "help": "Try a different model or check token permissions"
 
 
 
258
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
- @app.post("/summarize")
261
- async def summarize(request: SummarizeRequest):
262
- """Summarize text with AI"""
263
-
264
- start_time = time.time()
265
-
266
- if not request.content or len(request.content.strip()) < 10:
267
- raise HTTPException(status_code=400, detail="Content is too short (min 10 characters)")
268
-
269
- # Create prompt based on language
270
- if request.language.lower() in ["chinese", "zh", "cn"]:
271
- prompt = f"""请用中文总结以下内容,提取3-5个关键要点,保持简洁:
272
-
273
- {request.content[:2000]} # Limit content length
274
-
275
- 总结:"""
276
- else:
277
- prompt = f"""Please summarize the following content in English, extract 3-5 key points:
278
-
279
- {request.content[:2000]}
280
-
281
- Summary:"""
282
 
 
 
 
 
 
 
 
283
  try:
284
- # Limit content length to avoid timeout
285
- content = request.content[:2000] if len(request.content) > 2000 else request.content
286
-
287
- # Call the AI
288
- summary = call_huggingface_api(
289
- model=request.model,
290
- prompt=prompt,
291
- max_tokens=min(request.max_length, 300) # Limit tokens
292
- )
293
 
294
- processing_time = time.time() - start_time
295
-
296
- return {
297
- "success": True,
298
- "summary": summary.strip(),
299
- "model": request.model,
300
- "model_full_name": MODELS.get(request.model, MODELS["qwen"]),
301
- "original_length": len(request.content),
302
- "summary_length": len(summary),
303
- "processing_time_seconds": round(processing_time, 2),
304
- "compression_ratio": f"{len(summary)/max(len(content), 1)*100:.1f}%"
 
 
 
 
 
305
  }
306
 
307
- except Exception as e:
308
- raise HTTPException(
309
- status_code=500,
310
- detail={
311
- "error": f"Summarization failed: {str(e)}",
312
- "suggestion": "Try with model='phi' for faster response"
313
- }
314
  )
 
 
 
 
 
 
 
 
 
 
 
 
315
 
316
- @app.post("/create_story")
317
- async def create_story(request: StoryRequest):
318
- """Create viral story"""
319
-
320
- start_time = time.time()
321
-
322
- # Create prompt
323
- if request.language.lower() in ["chinese", "zh", "cn"]:
324
- prompt = f"""创作一个关于"{request.topic}"的病毒式故事:
325
-
326
- 要求:
327
- 1. 提供3个吸引人的标题
328
- 2. 故事简短有力(300字内)
329
- 3. 适合{request.platform}平台
330
- 4. 包含传播分析
331
- 5. 添加相关标签
332
-
333
- 请开始:"""
334
- else:
335
- prompt = f"""Create a viral story about "{request.topic}" for {request.platform}:
336
-
337
- Requirements:
338
- 1. Provide 3 catchy titles
339
- 2. Keep story short (under 300 words)
340
- 3. Include virality analysis
341
- 4. Add relevant hashtags
342
-
343
- Start:"""
344
 
 
 
 
 
 
 
345
  try:
346
- story = call_huggingface_api(
347
- model=request.model,
348
- prompt=prompt,
349
- max_tokens=600
350
- )
351
 
352
- processing_time = time.time() - start_time
 
 
 
 
353
 
354
- return {
355
- "success": True,
356
- "story": story.strip(),
357
- "model": request.model,
358
- "model_full_name": MODELS.get(request.model, MODELS["qwen"]),
359
- "topic": request.topic,
360
- "platform": request.platform,
361
- "processing_time_seconds": round(processing_time, 2)
362
- }
 
 
363
 
364
  except Exception as e:
365
- raise HTTPException(
366
- status_code=500,
367
- detail={
368
- "error": f"Story creation failed: {str(e)}",
369
- "suggestion": "Try with model='phi' or shorten your topic"
370
- }
371
- )
372
 
373
- @app.post("/chat")
374
- async def chat(request: ChatRequest):
375
- """General chat endpoint"""
 
 
376
 
377
- start_time = time.time()
 
 
 
 
 
 
378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  try:
380
- response = call_huggingface_api(
381
- model=request.model,
382
- prompt=request.prompt,
383
- max_tokens=min(request.max_tokens, 1000)
384
- )
385
 
386
- processing_time = time.time() - start_time
 
387
 
388
- return {
389
- "success": True,
390
- "response": response.strip(),
391
- "model": request.model,
392
- "model_full_name": MODELS.get(request.model, MODELS["qwen"]),
393
- "processing_time_seconds": round(processing_time, 2)
394
- }
 
 
 
 
395
 
396
  except Exception as e:
397
- raise HTTPException(
398
- status_code=500,
399
- detail=f"Chat failed: {str(e)}"
400
- )
401
 
402
- # Simple fallback for testing
403
- @app.post("/summarize_simple")
404
- async def summarize_simple(request: Dict[str, Any]):
405
- """Simple summarization without complex AI"""
406
-
407
- content = request.get("content", "")
408
-
409
- if not content or len(content) < 20:
410
- return {
411
- "success": True,
412
- "summary": "内容太短,无法总结。",
413
- "model": "simple",
414
- "processing_time_seconds": 0.01
415
- }
416
-
417
- # Simple rule-based summarization for Chinese
418
- if any(char in content for char in ["", "", "?"]):
419
- sentences = []
420
- for char in ["。", "!", "?"]:
421
- if char in content:
422
- parts = content.split(char)
423
- sentences.extend([p + char for p in parts[:-1]])
424
-
425
- if sentences:
426
- summary = "。".join(sentences[:3]) + "。"
427
- else:
428
- summary = content[:100] + "..."
429
- else:
430
- summary = content[:100] + "..."
431
-
432
- return {
433
- "success": True,
434
- "summary": summary,
435
- "model": "simple_fallback",
436
- "processing_time_seconds": 0.01
437
- }
438
 
439
  if __name__ == "__main__":
440
- import uvicorn
441
-
442
- port = int(os.getenv("PORT", 7860))
443
-
444
- logger.info("=" * 60)
445
- logger.info("🚀 AI Summarization API (v2.0)")
446
- logger.info(f"🔑 HuggingFace Token: {'✅ Configured' if HF_TOKEN else '❌ NOT CONFIGURED'}")
447
- logger.info(f"🌐 Using API: router.huggingface.co")
448
- logger.info(f"🤖 Available models: {list(MODELS.keys())}")
449
- logger.info(f"⭐ Recommended for Chinese: qwen")
450
- logger.info(f"⚡ Lightweight backup: phi")
451
- logger.info("=" * 60)
452
-
453
- if not HF_TOKEN:
454
- logger.error("❌ ERROR: HUGGINGFACE_TOKEN is required!")
455
- logger.info("ℹ️ Steps to fix:")
456
- logger.info("1. Go to: https://huggingface.co/settings/tokens")
457
- logger.info("2. Create new token with 'read' access")
458
- logger.info("3. Add to Space: Settings → Repository secrets")
459
- logger.info("4. Restart the Space")
460
-
461
- uvicorn.run(app, host="0.0.0.0", port=port)
 
 
 
 
 
1
  import os
2
+ import uuid
3
+ import httpx
4
+ import torch
5
  import logging
6
+ import json
7
+ import asyncio
8
+ from typing import Dict, Optional, List, Union
9
+ from fastapi import FastAPI, Request, BackgroundTasks, HTTPException, Depends
10
+ from fastapi.responses import JSONResponse
11
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM
13
+ import uvicorn
14
+ from contextlib import asynccontextmanager
15
 
16
# Configuration. Every value except MODEL_ID can be overridden through the
# environment; the second argument of each os.getenv call is the fallback.
MODEL_ID = "google/gemma-1.1-2b-it"
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Fallback bearer key used when API_KEY is unset. It is hard-coded in the
# source, so anyone who can read this file knows it.
_DEFAULT_API_KEY = "default-key-123"
API_KEY = os.getenv("API_KEY", _DEFAULT_API_KEY)
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "450"))
DEVICE = os.getenv("DEVICE", "cpu")
PORT = int(os.getenv("PORT", "7860"))

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# The fallback key is kept so existing deployments keep working, but running
# with it is made visible in the logs instead of passing silently.
if API_KEY == _DEFAULT_API_KEY:
    logger.warning(
        "API_KEY is not set; using the built-in default key. "
        "Set the API_KEY environment variable before exposing this service."
    )

# Security: bearer-token scheme consumed by the API-key dependency.
security = HTTPBearer()

# Job storage: in-memory map of job id -> job record. Nothing in the code
# shown here removes entries, and the contents do not survive a restart.
jobs: Dict[str, dict] = {}
 
 
36
 
37
class AIGenerator:
    """Holder for the tokenizer and causal-LM model identified by MODEL_ID.

    Nothing is loaded at construction time; load_model() fills the
    attributes in on first use.
    """

    def __init__(self):
        # Both stay None until load_model() succeeds.
        self.tokenizer = None
        self.model = None
        # True once tokenizer and model are both ready for generation.
        self.loaded = False
        # Text of the most recent load failure, or None.
        self.load_error = None

    def load_model(self):
        """Load the tokenizer and model if that has not happened yet.

        Returns:
            True when the model is (already or newly) loaded, False when
            loading raised; in that case the message is kept in
            ``self.load_error`` and logged.
        """
        if self.loaded:
            return True

        logger.info("Loading model...")
        # HF_TOKEN defaults to "" when the variable is unset. Pass None in
        # that case so no empty credential is handed to the hub client.
        # NOTE(review): confirm how the installed huggingface_hub treats "".
        token = HF_TOKEN or None
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=token)
            logger.info("✅ Tokenizer loaded")

            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float32,
                token=token,
                device_map=None
            )
            model = model.to(DEVICE)
            model.eval()
        except Exception as e:
            self.load_error = str(e)
            logger.error(f" Model loading failed: {str(e)}")
            return False

        # Publish the model only once it is fully set up, and drop any error
        # left over from an earlier failed attempt.
        self.model = model
        self.load_error = None
        self.loaded = True
        logger.info(" Model loaded successfully")
        return True


# Global generator instance
generator = AIGenerator()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
async def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Check the bearer token of the request against API_KEY.

    Args:
        credentials: Bearer credentials extracted by the ``security`` scheme.

    Returns:
        True when the supplied token equals API_KEY.

    Raises:
        HTTPException: 401 when the token does not match.
    """
    # Local import: `secrets` is not among the module-level imports.
    import secrets

    # Compare in constant time so response timing does not reveal how much of
    # the key was correct. Bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    supplied = credentials.credentials.encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True
 
 
 
 
 
81
 
82
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Emit a startup log line, then yield control to the application.

    No statements follow the yield, so nothing extra runs at shutdown.
    """
    startup_message = "🚀 API Server starting up..."
    logger.info(startup_message)
    yield


# Application object wired to the lifespan handler defined above.
app = FastAPI(lifespan=lifespan)
89
+
90
def generate_text(prompt: str, max_tokens: Optional[int] = None) -> str:
    """
    Generate a completion for ``prompt`` with the module-level model.

    The model is loaded on first use. The prompt is tokenized with truncation
    at 512 tokens and the completion is sampled (top_p=0.9, temperature=0.8,
    repetition_penalty=1.1).

    Args:
        prompt: The prompt/instructions to send to the AI model
        max_tokens: Maximum number of new tokens to generate. None or 0 falls
            back to the MAX_TOKENS setting.

    Returns:
        Only the newly generated text, with surrounding whitespace stripped.

    Raises:
        Exception: if the model cannot be loaded.
        ValueError: if ``max_tokens`` is negative or not convertible to int.
        Any error raised by tokenization or generation is logged and re-raised.
    """
    try:
        if not generator.loaded and not generator.load_model():
            raise Exception(f"Model failed to load: {generator.load_error}")

        # Falsy values keep the historical fallback to MAX_TOKENS; anything
        # else must be a usable positive count.
        token_budget = int(max_tokens) if max_tokens else MAX_TOKENS
        if token_budget <= 0:
            raise ValueError("max_tokens must be a positive integer")

        logger.info(f"📝 Generating text with prompt (first 200 chars): {prompt[:200]}...")

        # Tokenize the prompt
        inputs = generator.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        prompt_token_count = inputs["input_ids"].shape[-1]

        # Generate text
        with torch.no_grad():
            outputs = generator.model.generate(
                **inputs,
                max_new_tokens=token_budget,
                do_sample=True,
                top_p=0.9,
                temperature=0.8,
                pad_token_id=generator.tokenizer.eos_token_id,
                repetition_penalty=1.1
            )

        # The returned sequence starts with the prompt tokens. Cut them off by
        # position instead of by string replacement: the decoded prompt need
        # not match the original text (truncation, tokenizer normalisation),
        # and str.replace would also delete later repeats of the prompt.
        new_tokens = outputs[0][prompt_token_count:]
        generated_text = generator.tokenizer.decode(
            new_tokens, skip_special_tokens=True
        ).strip()

        logger.info(f"✅ Generated {len(generated_text)} characters")
        return generated_text

    except Exception as e:
        logger.error(f"❌ Text generation failed: {str(e)}")
        raise
144
+
145
async def _post_webhook(callback_url, payload, timeout):
    """POST ``payload`` as JSON to ``callback_url`` and return the response."""
    async with httpx.AsyncClient(timeout=timeout) as client:
        return await client.post(
            callback_url,
            json=payload,
            headers={"Content-Type": "application/json"}
        )


async def process_job(job_id: str, prompt: str, callback_url: str = None):
    """Background task: generate text for a job, store the outcome, notify.

    The result (or the failure) replaces the record in ``jobs[job_id]``. If a
    callback URL was given, the same information plus the job id is POSTed to
    it; webhook errors are logged and never change the stored outcome.

    Args:
        job_id: Key of the job record in ``jobs``.
        prompt: Prompt to generate from.
        callback_url: Optional webhook URL to notify when the job ends.
    """
    try:
        logger.info(f"🎯 Processing job {job_id}")

        # The request handler stores the caller's max_tokens in the job
        # record; honour it instead of always using the default.
        requested_tokens = jobs.get(job_id, {}).get("max_tokens")

        # One shared lock keeps generation strictly one-at-a-time (as it was
        # when the call blocked the event loop). It is created lazily here,
        # on the event loop, so there is no race in creating it.
        generation_lock = getattr(app.state, "generation_lock", None)
        if generation_lock is None:
            generation_lock = app.state.generation_lock = asyncio.Lock()

        # generate_text is blocking and slow; run it in a worker thread so
        # the event loop can keep serving other requests meanwhile.
        loop = asyncio.get_running_loop()
        async with generation_lock:
            generated_text = await loop.run_in_executor(
                None, generate_text, prompt, requested_tokens
            )

        # Store job results
        outcome = {
            "status": "complete",
            "result": generated_text,
            "prompt": prompt,
            "text_length": len(generated_text),
            "model": MODEL_ID
        }
        jobs[job_id] = dict(outcome)

        logger.info(f"✅ Completed job {job_id}")

    except Exception as e:
        error_msg = f"Job failed: {str(e)}"
        logger.error(f"❌ Job {job_id} failed: {error_msg}")

        # Store failure information
        failure = {
            "status": "failed",
            "error": error_msg,
            "prompt": prompt
        }
        jobs[job_id] = dict(failure)

        # Send failure webhook if callback URL exists
        if callback_url:
            try:
                await _post_webhook(callback_url, {"job_id": job_id, **failure}, 10.0)
            except Exception as e:
                logger.error(f"Failed to send error webhook: {e}")
        return

    # Send webhook callback if URL provided
    if callback_url:
        try:
            logger.info(f"📨 Sending webhook to: {callback_url}")
            response = await _post_webhook(callback_url, {"job_id": job_id, **outcome}, 30.0)

            if 200 <= response.status_code < 300:
                logger.info(f"✅ Webhook delivered successfully: {response.status_code}")
            else:
                logger.warning(f"⚠️ Webhook returned non-2xx status: {response.status_code}")

        except Exception as e:
            logger.error(f"❌ Webhook failed: {str(e)}")
220
 
221
@app.post("/api/generate")
async def generate(
    request: Request,
    background_tasks: BackgroundTasks,
    auth: bool = Depends(verify_api_key)
):
    """
    Endpoint to generate text with custom prompt instructions

    The job is registered in ``jobs`` and processed in the background; the
    response carries the job id to poll.

    Expected JSON payload:
    {
        "prompt": "Your instructions here",
        "max_tokens": 450,  # optional
        "callback_url": "https://your-webhook.url"  # optional
    }

    Raises:
        HTTPException: 400 when the body is not a JSON object, the prompt is
            missing or not a string, or an optional field has the wrong type.
    """
    try:
        data = await request.json()
        if not isinstance(data, dict):
            raise HTTPException(status_code=400, detail="Request body must be a JSON object")

        # Validate input
        prompt = data.get("prompt")
        if not prompt:
            raise HTTPException(status_code=400, detail="Prompt is required")
        if not isinstance(prompt, str):
            raise HTTPException(status_code=400, detail="Prompt must be a string")

        # bool is a subclass of int, so it has to be excluded explicitly.
        max_tokens = data.get("max_tokens")
        if max_tokens is not None and (
            isinstance(max_tokens, bool)
            or not isinstance(max_tokens, int)
            or max_tokens < 0
        ):
            raise HTTPException(status_code=400, detail="max_tokens must be a non-negative integer")

        # NOTE(review): callback_url is caller-supplied and the server later
        # POSTs to it; consider restricting the hosts that may be targeted.
        callback_url = data.get("callback_url")
        if callback_url is not None and not isinstance(callback_url, str):
            raise HTTPException(status_code=400, detail="callback_url must be a string")

        job_id = str(uuid.uuid4())
        logger.info(f"📥 Received job {job_id} with prompt length: {len(prompt)}")

        # Store initial job data
        jobs[job_id] = {
            "status": "processing",
            "callback_url": callback_url,
            "prompt": prompt,
            "max_tokens": max_tokens
        }

        # Process job in background
        background_tasks.add_task(
            process_job,
            job_id,
            prompt,
            callback_url
        )

        return JSONResponse({
            "job_id": job_id,
            "status": "queued",
            "message": "Text generation started",
            "model": MODEL_ID,
            "estimated_time": "30-60 seconds"
        })

    except HTTPException:
        # Deliberate HTTP errors keep their own status and detail instead of
        # being re-wrapped by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"❌ Generation request error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
278
 
279
@app.post("/api/generate-sync")
async def generate_sync(
    request: Request,
    auth: bool = Depends(verify_api_key)
):
    """
    Synchronous endpoint for immediate generation (for smaller requests)

    Expected JSON payload:
    {
        "prompt": "Your instructions here",
        "max_tokens": 450  # optional
    }

    Raises:
        HTTPException: 400 for an invalid request body (bad JSON, not an
            object, missing or non-string prompt, bad max_tokens); 500 when
            generation itself fails.
    """
    # Client errors are handled before the generation try-block so that they
    # are reported as 400 rather than being converted to 500.
    try:
        data = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON")
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    prompt = data.get("prompt")
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")
    if not isinstance(prompt, str):
        raise HTTPException(status_code=400, detail="Prompt must be a string")

    # bool is a subclass of int, so it has to be excluded explicitly.
    max_tokens = data.get("max_tokens")
    if max_tokens is not None and (
        isinstance(max_tokens, bool)
        or not isinstance(max_tokens, int)
        or max_tokens < 0
    ):
        raise HTTPException(status_code=400, detail="max_tokens must be a non-negative integer")

    try:
        logger.info(f"📝 Synchronous generation request with prompt length: {len(prompt)}")

        # One shared lock keeps generation strictly one-at-a-time; it is
        # created lazily on the event loop, so creating it cannot race.
        generation_lock = getattr(app.state, "generation_lock", None)
        if generation_lock is None:
            generation_lock = app.state.generation_lock = asyncio.Lock()

        # generate_text is blocking and slow; run it in a worker thread so
        # the event loop stays responsive while this request waits.
        loop = asyncio.get_running_loop()
        async with generation_lock:
            generated_text = await loop.run_in_executor(
                None, generate_text, prompt, max_tokens
            )

        return JSONResponse({
            "status": "success",
            "result": generated_text,
            "text_length": len(generated_text),
            "model": MODEL_ID
        })

    except Exception as e:
        logger.error(f"❌ Synchronous generation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
317
 
318
@app.get("/api/status/{job_id}")
async def get_status(job_id: str, auth: bool = Depends(verify_api_key)):
    """Return the stored record of a job, or respond 404 if the id is unknown."""
    try:
        record = jobs[job_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Job not found")
    return JSONResponse(record)
325
+
326
@app.get("/health")
async def health_check():
    """Report model state and simple counters derived from the job store."""
    completed_lengths = []
    failed_total = 0
    for record in jobs.values():
        state = record.get("status")
        if state == "complete":
            completed_lengths.append(record.get("text_length", 0))
        elif state == "failed":
            failed_total += 1

    # max(1, ...) keeps the average at 0 when no job has completed yet.
    completed_total = len(completed_lengths)
    mean_length = sum(completed_lengths) / max(1, completed_total)

    payload = {
        "status": "healthy",
        "model_loaded": generator.loaded,
        "model_id": MODEL_ID,
        "total_jobs": len(jobs),
        "completed_jobs": completed_total,
        "failed_jobs": failed_total,
        "average_text_length": round(mean_length, 2)
    }
    return JSONResponse(payload)
341
+
342
@app.post("/api/test")
async def test_generation(
    request: Request,
    auth: bool = Depends(verify_api_key)
):
    """Test endpoint with custom prompt

    The JSON body and its "prompt" field are both optional: an empty body,
    invalid JSON, a non-object body, or a missing/empty prompt all fall back
    to a built-in sample prompt. Generation is capped at 200 new tokens.

    Returns:
        JSON with status "success" and the generated text, or status "error"
        with a message (HTTP 400 for a non-string prompt, 500 when generation
        raises; a model-load failure keeps its original HTTP 200 response).
    """
    default_prompt = "Write a short story about AI in 100 words."
    try:
        # One shared lock keeps model loading and generation one-at-a-time;
        # it is created lazily on the event loop, so creating it cannot race.
        generation_lock = getattr(app.state, "generation_lock", None)
        if generation_lock is None:
            generation_lock = app.state.generation_lock = asyncio.Lock()
        loop = asyncio.get_running_loop()

        # Loading is blocking and slow, so it runs in a worker thread.
        async with generation_lock:
            if not generator.loaded:
                model_ready = await loop.run_in_executor(None, generator.load_model)
                if not model_ready:
                    return JSONResponse({"status": "error", "error": "Model failed to load"})

        # An absent or unparsable body means "use the default prompt".
        try:
            data = await request.json()
        except ValueError:
            data = {}
        if not isinstance(data, dict):
            data = {}

        # `or` also covers an explicit null/empty prompt.
        test_prompt = data.get("prompt") or default_prompt
        if not isinstance(test_prompt, str):
            return JSONResponse(
                {"status": "error", "error": "prompt must be a string"},
                status_code=400
            )

        logger.info(f"🧪 Testing generation with prompt: {test_prompt[:100]}...")

        async with generation_lock:
            generated_text = await loop.run_in_executor(
                None, generate_text, test_prompt, 200
            )

        return JSONResponse({
            "status": "success",
            "prompt": test_prompt,
            "result": generated_text,
            "text_length": len(generated_text),
            "model": MODEL_ID
        })

    except Exception as e:
        logger.error(f"❌ Test generation failed: {str(e)}")
        return JSONResponse({"status": "error", "error": str(e)}, status_code=500)
 
 
371
 
372
@app.get("/")
async def root():
    """Describe the service: name, version, model and available endpoints."""
    endpoint_index = {
        "generate_async": "POST /api/generate (with 'prompt' field)",
        "generate_sync": "POST /api/generate-sync (with 'prompt' field)",
        "check_status": "GET /api/status/{job_id}",
        "health": "GET /health",
        "test": "POST /api/test (with optional 'prompt' field)"
    }
    description = {
        "message": "AI Text Generation API",
        "version": "3.0",
        "model": MODEL_ID,
        "features": "General purpose text generation with custom prompts",
        "endpoints": endpoint_index,
        "usage": "Send POST request with {'prompt': 'your instructions'}",
        "status": "operational"
    }
    return JSONResponse(description)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces at the configured port.
    server_options = {
        "host": "0.0.0.0",
        "port": PORT,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)