yukee1992 committed
Commit 944a2f2 · verified · 1 Parent(s): 94d10b6

Update app.py

Files changed (1)
  1. app.py +320 -42
app.py CHANGED
@@ -1,62 +1,340 @@
 import os
-import sys
+import uuid
+import httpx
+import torch
 import logging
-from fastapi import FastAPI
+import json
+import asyncio
+from typing import Dict, Optional
+from fastapi import FastAPI, Request, BackgroundTasks, HTTPException, Depends
 from fastapi.responses import JSONResponse
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import uvicorn
+from contextlib import asynccontextmanager
 
-# Setup logging to see what's happening
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+# Configuration - NOW WORKING!
+MODEL_ID = "google/gemma-1.1-2b-it"
+HF_TOKEN = os.getenv("HF_TOKEN", "")
+API_KEY = os.getenv("API_KEY", "default-key-123")
+MAX_TOKENS = int(os.getenv("MAX_TOKENS", "450"))
+DEVICE = os.getenv("DEVICE", "cpu")
+PORT = int(os.getenv("PORT", "7860"))
+
+# Setup logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
 logger = logging.getLogger(__name__)
 
-app = FastAPI()
+# Security
+security = HTTPBearer()
 
-# Debug: Print ALL environment variables
-logger.info("=" * 50)
-logger.info("STARTING APP - CHECKING ENVIRONMENT")
-logger.info("=" * 50)
+# Job storage
+jobs: Dict[str, dict] = {}
 
-for key in sorted(os.environ.keys()):
-    value = os.environ.get(key)
-    if any(term in key.lower() for term in ['token', 'key', 'secret', 'pass', 'auth']):
-        # Hide sensitive values but show they exist
-        logger.info(f"ENV {key}: {'SET' if value else 'NOT SET'} (hidden for security)")
-    else:
-        logger.info(f"ENV {key}: {value}")
+class AIGenerator:
+    def __init__(self):
+        self.tokenizer = None
+        self.model = None
+        self.loaded = False
+        self.load_error = None
+
+    def load_model(self):
+        """Load the AI model with authentication"""
+        if self.loaded:
+            return True
+
+        logger.info(f"🚀 Loading model: {MODEL_ID}")
+
+        if not HF_TOKEN:
+            logger.error("❌ HF_TOKEN is not set!")
+            self.load_error = "HF_TOKEN environment variable is not set"
+            return False
+
+        try:
+            # Load tokenizer with authentication
+            logger.info("📥 Loading tokenizer...")
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                MODEL_ID,
+                token=HF_TOKEN  # Key change: use 'token' parameter
+            )
+
+            # Set padding token
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            logger.info("✅ Tokenizer loaded")
+
+            # Load model with authentication
+            logger.info("📥 Loading model...")
+            self.model = AutoModelForCausalLM.from_pretrained(
+                MODEL_ID,
+                torch_dtype=torch.float32,
+                token=HF_TOKEN,  # Key change: use 'token' parameter
+                device_map=None
+            )
+
+            # Move to device
+            self.model = self.model.to(DEVICE)
+            self.model.eval()
+
+            self.loaded = True
+            logger.info("🎉 Model loaded successfully!")
+            return True
+
+        except Exception as e:
+            self.load_error = str(e)
+            logger.error(f"❌ Model loading failed: {str(e)}")
+            return False
 
-# Explicitly check for HF_TOKEN
-HF_TOKEN = os.getenv("HF_TOKEN")
-logger.info(f"HF_TOKEN explicitly: {'SET' if HF_TOKEN else 'NOT SET'}")
+# Global generator instance
+generator = AIGenerator()
 
-# Check other common names
-for var_name in ["HF_TOKEN", "HUGGINGFACE_TOKEN", "HUGGING_FACE_HUB_TOKEN"]:
-    value = os.getenv(var_name)
-    if value:
-        logger.info(f"Found in {var_name}: {'SET' if value else 'NOT SET'}")
+async def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
+    """Verify API key"""
+    if credentials.credentials != API_KEY:
+        raise HTTPException(status_code=401, detail="Invalid API key")
+    return True
 
-@app.get("/")
-async def root():
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Lifespan manager - preload model on startup"""
+    logger.info("🚀 Starting AI API Server...")
+    logger.info(f"📊 Config: Model={MODEL_ID}, Device={DEVICE}, MaxTokens={MAX_TOKENS}")
+
+    # Try to preload model (non-blocking)
+    try:
+        generator.load_model()
+    except Exception as e:
+        logger.warning(f"Model preloading failed, will load on first request: {e}")
+
+    yield
+
+app = FastAPI(lifespan=lifespan)
+
+def generate_text(prompt: str, max_tokens: int = None) -> str:
+    """Generate text based on prompt"""
+    try:
+        if not generator.loaded:
+            if not generator.load_model():
+                raise Exception(f"Model failed to load: {generator.load_error}")
+
+        logger.info(f"📝 Generating text for prompt: '{prompt[:50]}...'")
+
+        # Tokenize
+        inputs = generator.tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=512
+        )
+
+        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
+
+        # Generate
+        with torch.no_grad():
+            outputs = generator.model.generate(
+                **inputs,
+                max_new_tokens=max_tokens or MAX_TOKENS,
+                do_sample=True,
+                top_p=0.9,
+                temperature=0.8,
+                pad_token_id=generator.tokenizer.pad_token_id,
+                repetition_penalty=1.1
+            )
+
+        # Decode
+        generated_text = generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Remove prompt if included
+        if prompt in generated_text:
+            generated_text = generated_text.replace(prompt, "").strip()
+
+        logger.info(f"✅ Generated {len(generated_text)} characters")
+        return generated_text
+
+    except Exception as e:
+        logger.error(f"❌ Generation failed: {str(e)}")
+        raise
+
+@app.post("/api/generate-sync")
+async def generate_sync(
+    request: Request,
+    auth: bool = Depends(verify_api_key)
+):
+    """
+    Synchronous text generation
+    Body: {"prompt": "your text", "max_tokens": 100}
+    """
+    try:
+        data = await request.json()
+
+        if not data.get("prompt"):
+            raise HTTPException(status_code=400, detail="Prompt is required")
+
+        prompt = data["prompt"]
+        max_tokens = data.get("max_tokens")
+
+        logger.info(f"📥 Sync request: '{prompt[:50]}...'")
+
+        generated_text = generate_text(prompt, max_tokens)
+
+        return JSONResponse({
+            "status": "success",
+            "result": generated_text,
+            "prompt": prompt,
+            "text_length": len(generated_text),
+            "model": MODEL_ID
+        })
+
+    except Exception as e:
+        logger.error(f"❌ Sync generation error: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/api/generate")
+async def generate_async(
+    request: Request,
+    background_tasks: BackgroundTasks,
+    auth: bool = Depends(verify_api_key)
+):
+    """
+    Asynchronous text generation (for longer tasks)
+    Body: {"prompt": "your text", "max_tokens": 100, "callback_url": "optional"}
+    """
+    try:
+        data = await request.json()
+        job_id = str(uuid.uuid4())
+
+        if not data.get("prompt"):
+            raise HTTPException(status_code=400, detail="Prompt is required")
+
+        prompt = data["prompt"]
+        max_tokens = data.get("max_tokens")
+        callback_url = data.get("callback_url")
+
+        logger.info(f"📥 Async request {job_id}")
+
+        jobs[job_id] = {
+            "status": "processing",
+            "prompt": prompt
+        }
+
+        # Process in background
+        background_tasks.add_task(
+            process_job_async,
+            job_id,
+            prompt,
+            max_tokens,
+            callback_url
+        )
+
+        return JSONResponse({
+            "job_id": job_id,
+            "status": "queued",
+            "message": "Generation started",
+            "model": MODEL_ID
+        })
+
+    except Exception as e:
+        logger.error(f"❌ Async request error: {str(e)}")
+        raise HTTPException(status_code=400, detail=str(e))
+
+async def process_job_async(job_id: str, prompt: str, max_tokens: int = None, callback_url: str = None):
+    """Background processing for async jobs"""
+    try:
+        logger.info(f"🔄 Processing async job {job_id}")
+
+        generated_text = generate_text(prompt, max_tokens)
+
+        jobs[job_id] = {
+            "status": "complete",
+            "result": generated_text,
+            "prompt": prompt,
+            "text_length": len(generated_text)
+        }
+
+        logger.info(f"✅ Completed async job {job_id}")
+
+        # Send callback if provided
+        if callback_url:
+            try:
+                async with httpx.AsyncClient(timeout=30.0) as client:
+                    await client.post(
+                        callback_url,
+                        json={
+                            "job_id": job_id,
+                            "status": "complete",
+                            "result": generated_text,
+                            "prompt": prompt
+                        }
+                    )
+            except Exception as e:
+                logger.error(f"❌ Callback failed: {e}")
+
+    except Exception as e:
+        error_msg = str(e)
+        logger.error(f"❌ Async job {job_id} failed: {error_msg}")
+        jobs[job_id] = {
+            "status": "failed",
+            "error": error_msg,
+            "prompt": prompt
+        }
+
+@app.get("/api/status/{job_id}")
+async def get_status(job_id: str, auth: bool = Depends(verify_api_key)):
+    """Check job status"""
+    if job_id not in jobs:
+        raise HTTPException(status_code=404, detail="Job not found")
+
+    return JSONResponse(jobs[job_id])
+
+@app.get("/health")
+async def health_check():
+    """Health check endpoint"""
     return JSONResponse({
-        "message": "AI API Debug Version",
-        "status": "running",
-        "hf_token_set": bool(os.getenv("HF_TOKEN")),
-        "all_env_vars": {k: 'SET' if v else 'NOT SET' for k, v in os.environ.items() if 'TOKEN' in k or 'KEY' in k}
+        "status": "healthy",
+        "model_loaded": generator.loaded,
+        "model": MODEL_ID,
+        "device": DEVICE,
+        "max_tokens": MAX_TOKENS
     })
 
-@app.get("/test-env")
-async def test_env():
-    """Test if HF_TOKEN is available"""
-    hf_token = os.getenv("HF_TOKEN", "")
+@app.get("/model-info")
+async def model_info():
+    """Model information"""
     return JSONResponse({
-        "hf_token_exists": bool(hf_token),
-        "hf_token_length": len(hf_token) if hf_token else 0,
-        "hf_token_preview": f"{hf_token[:10]}..." if hf_token and len(hf_token) > 10 else hf_token,
-        "all_env_keys": list(os.environ.keys())
+        "model": MODEL_ID,
+        "loaded": generator.loaded,
+        "error": generator.load_error,
+        "device": DEVICE,
+        "requires_auth": True,
+        "token_available": bool(HF_TOKEN)
     })
 
-@app.get("/health")
-async def health():
-    return JSONResponse({"status": "healthy"})
+@app.get("/")
+async def root():
+    """Root endpoint"""
+    return JSONResponse({
+        "message": "🤖 AI Text Generation API",
+        "version": "1.0",
+        "model": MODEL_ID,
+        "status": "operational" if generator.loaded else "model_loading",
+        "endpoints": {
+            "generate_sync": "POST /api/generate-sync",
+            "generate_async": "POST /api/generate",
+            "check_status": "GET /api/status/{job_id}",
+            "health": "GET /health",
+            "model_info": "GET /model-info"
+        },
+        "usage": 'curl -X POST /api/generate-sync -H "Authorization: Bearer YOUR_KEY" -d \'{"prompt":"Hello"}\''
+    })
 
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
+    uvicorn.run(
+        app,
+        host="0.0.0.0",
+        port=PORT,
+        log_level="info"
+    )
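
For reference, a minimal client sketch against the synchronous endpoint this commit adds. The base URL and key are illustrative assumptions, not part of the commit: it presumes the server is reachable locally on the default PORT and that API_KEY was left at its server-side default.

# Hypothetical client for POST /api/generate-sync (httpx, the same library the app uses).
import httpx

BASE_URL = "http://localhost:7860"  # assumption: local deployment on the default PORT
API_KEY = "default-key-123"         # assumption: server-side default API_KEY

resp = httpx.post(
    f"{BASE_URL}/api/generate-sync",
    headers={"Authorization": f"Bearer {API_KEY}"},  # bearer auth checked by verify_api_key
    json={"prompt": "Write a haiku about tokens", "max_tokens": 64},
    timeout=120.0,  # CPU-only generation with gemma-1.1-2b-it can take a while
)
resp.raise_for_status()
print(resp.json()["result"])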
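
And a sketch of the asynchronous flow: POST /api/generate returns a job_id immediately while the BackgroundTasks job moves the in-memory jobs record from "processing" to "complete" or "failed", so the client polls GET /api/status/{job_id}. Same illustrative base URL and key as above.

# Hypothetical client for the async job flow: submit, then poll until done.
import time
import httpx

BASE_URL = "http://localhost:7860"                     # assumption: local deployment
HEADERS = {"Authorization": "Bearer default-key-123"}  # assumption: default API_KEY

# Submit: the server queues a background job and answers with a job_id
job = httpx.post(
    f"{BASE_URL}/api/generate",
    headers=HEADERS,
    json={"prompt": "Explain bearer authentication in one paragraph"},
).json()

# Poll until the background task marks the job complete or failed
while True:
    status = httpx.get(f"{BASE_URL}/api/status/{job['job_id']}", headers=HEADERS).json()
    if status["status"] in ("complete", "failed"):
        break
    time.sleep(2)

print(status.get("result") or status.get("error"))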