Patryk Studzinski committed on
Commit
9222e8a
·
1 Parent(s): 1784558

feat: Add main backup and simplified service implementations with API endpoints

Browse files

- Implemented `main_backup.py` for a multi-model description enhancer API with endpoints for health checks, model management, description enhancement, and infill operations.
- Introduced `main_simple.py` for a simplified Bielik LLM service with endpoints for chat and text generation.
- Added request and response models for structured API interactions.
- Created unit tests in `test_simplified.py` to validate API structure, request schemas, and default values without executing model logic.

Files changed (4) hide show
  1. app/main.py +145 -488
  2. app/main_backup.py +548 -0
  3. app/main_simple.py +202 -0
  4. test_simplified.py +132 -0
app/main.py CHANGED
@@ -1,83 +1,87 @@
1
  import os
2
- import time
3
- import asyncio
4
- import importlib
5
  import subprocess
6
  import sys
7
- from fastapi import FastAPI, HTTPException, Depends, Body
8
  from typing import Optional, List
9
- from pydantic import ValidationError
 
10
 
11
- # llama-cpp-python installed at runtime with CUDA support
12
- try:
13
- import llama_cpp
14
- except ImportError:
15
- print("[STARTUP] Installing llama-cpp-python with CUDA...")
16
- env = os.environ.copy()
17
- result = subprocess.run(
18
- [sys.executable, "-m", "pip", "install", "--quiet", "--prefer-binary",
19
- "--index-url", "https://abetlen.github.io/llama-cpp-python/whl/cu121",
20
- "llama-cpp-python[server]>=0.3.16"],
21
- capture_output=True,
22
- text=True
23
- )
24
- if result.returncode != 0:
25
- print("[STARTUP] CUDA wheel failed, trying CPU fallback...")
26
- print(f"[STARTUP] Error details: {result.stderr[:500]}")
27
- subprocess.run([sys.executable, "-m", "pip", "install", "--quiet", "llama-cpp-python>=0.3.16"], check=False)
28
- else:
29
- print("[STARTUP] llama-cpp-python with CUDA installed")
30
 
31
  from app.models.registry import registry, MODEL_CONFIG
32
- from fastapi.middleware.cors import CORSMiddleware
33
- from app.schemas.schemas import (
34
- EnhancedDescriptionResponse,
35
- CompareRequest,
36
- CompareResponse,
37
- ModelResult,
38
- ModelInfo,
39
- InfillRequest,
40
- InfillResponse,
41
- InfillResult,
42
- GapFill,
43
- CompareInfillRequest,
44
- CompareInfillResponse,
45
- ModelInfillResult,
46
- )
47
- from app.logic.infill_utils import (
48
- detect_gaps,
49
- parse_infill_response,
50
- apply_fills,
51
- build_fills_dict,
52
- normalize_gaps_to_tagged,
53
- )
54
- from app.auth.placeholder_auth import get_authenticated_user
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  app = FastAPI(
57
- title="Multi-Model Description Enhancer",
58
- description="AI-powered service for enhancing descriptions using multiple LLMs for A/B testing",
59
- version="3.0.0"
60
- )
61
-
62
- # CORS configuration
63
- app.add_middleware(
64
- CORSMiddleware,
65
- allow_origins=[
66
- "http://localhost:5173",
67
- "http://localhost:5174",
68
- os.getenv("FRONTEND_URL", "http://localhost:5173")
69
- ],
70
- allow_credentials=True,
71
- allow_methods=["POST", "GET"],
72
- allow_headers=["*"],
73
  )
74
 
75
  @app.on_event("startup")
76
  async def startup_event():
77
- """
78
- Startup event - models are loaded lazily on first request.
79
- No models are pre-loaded to conserve memory.
80
- """
81
  print("Application started. Models will be loaded lazily on first request.")
82
  print(f"Available models: {registry.get_available_model_names()}")
83
 
@@ -91,458 +95,111 @@ async def startup_event():
91
  except Exception as e:
92
  print(f"GPU check failed: {e}")
93
 
94
- # --- Helper function to load domain logic ---
95
- def get_domain_config(domain: str):
96
- try:
97
- module = importlib.import_module(f"app.domains.{domain}.config")
98
- return module.domain_config
99
- except (ImportError, AttributeError):
100
- raise HTTPException(status_code=404, detail=f"Domain '{domain}' not found or not configured correctly.")
101
-
102
- # --- API Endpoints ---
103
-
104
- @app.get("/")
105
- async def read_root():
106
- return {"message": "Welcome to the Multi-Model Description Enhancer API! Go to /docs for documentation."}
107
-
108
- @app.get("/health")
109
  async def health_check():
110
- """Check API health and model status."""
111
- models = registry.list_models()
112
- loaded_models = registry.get_loaded_models()
113
- active_model = registry.get_active_model()
114
-
115
  gpu_available = False
116
- gpu_name = "N/A"
117
  try:
118
  import torch
119
  gpu_available = torch.cuda.is_available()
120
- gpu_name = torch.cuda.get_device_name(0) if gpu_available else "N/A"
121
  except:
122
  pass
123
 
124
- return {
125
- "status": "ok",
126
- "available_models": len(models),
127
- "loaded_models": loaded_models,
128
- "active_local_model": active_model,
129
- "gpu_available": gpu_available,
130
- "gpu_device": gpu_name,
131
- }
132
 
133
- @app.get("/models", response_model=List[ModelInfo])
134
  async def list_models():
135
- """List all available models with their load status."""
136
- return registry.list_models()
 
 
 
 
 
 
 
 
137
 
138
- @app.post("/models/{model_name}/load")
139
- async def load_model(model_name: str):
140
  """
141
- Explicitly load a model into memory.
142
- For local models: unloads any previously loaded local model first.
143
- """
144
- if model_name not in registry.get_available_model_names():
145
- raise HTTPException(status_code=404, detail=f"Unknown model: {model_name}")
146
 
147
- try:
148
- info = await registry.load_model(model_name)
149
- return {"status": "loaded", "model": info}
150
- except Exception as e:
151
- raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
152
-
153
- @app.post("/models/{model_name}/unload")
154
- async def unload_model(model_name: str):
155
  """
156
- Explicitly unload a model from memory to free resources.
157
- """
158
- if model_name not in registry.get_available_model_names():
159
- raise HTTPException(status_code=404, detail=f"Unknown model: {model_name}")
160
-
161
- try:
162
- result = await registry.unload_model(model_name)
163
- return result
164
- except Exception as e:
165
- raise HTTPException(status_code=500, detail=f"Failed to unload model: {str(e)}")
166
-
167
- @app.post("/enhance-description", response_model=EnhancedDescriptionResponse)
168
- async def enhance_description(
169
- domain: str = Body(..., embed=True),
170
- data: dict = Body(..., embed=True),
171
- model: str = Body("bielik-1.5b", embed=True),
172
- user: Optional[dict] = Depends(get_authenticated_user)
173
- ):
174
- """
175
- Generate an enhanced description using a single model.
176
- - **domain**: The name of the domain (e.g., 'cars').
177
- - **data**: A dictionary with the data for the description.
178
- - **model**: Model to use (default: bielik-1.5b)
179
- """
180
- start_time = time.time()
181
-
182
  # Validate model
183
- if model not in registry.get_available_model_names():
184
- raise HTTPException(status_code=400, detail=f"Unknown model: {model}")
185
 
186
- # Load Domain Configuration
187
- domain_config = get_domain_config(domain)
188
- DomainSchema = domain_config["schema"]
189
- create_prompt = domain_config["create_prompt"]
190
-
191
- # Validate Input Data
192
- try:
193
- validated_data = DomainSchema(**data)
194
- except ValidationError as e:
195
- raise HTTPException(status_code=422, detail=f"Invalid data for domain '{domain}': {e}")
196
-
197
- # Prompt Construction
198
- chat_messages = create_prompt(validated_data)
199
-
200
- # Text Generation
201
  try:
202
- llm = await registry.get_model(model)
203
- generated_description = await llm.generate(
204
- chat_messages=chat_messages,
205
- max_new_tokens=150,
206
- temperature=0.75,
207
- top_p=0.9,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  )
209
  except Exception as e:
210
- print(f"Error during text generation with {model}: {e}")
211
  raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
212
 
213
- generation_time = time.time() - start_time
214
- user_email = user['email'] if user else "anonymous"
215
-
216
- return EnhancedDescriptionResponse(
217
- description=generated_description,
218
- model_used=MODEL_CONFIG[model]["id"],
219
- generation_time=round(generation_time, 2),
220
- user_email=user_email
221
- )
222
-
223
- @app.post("/compare", response_model=CompareResponse)
224
- async def compare_models(
225
- request: CompareRequest,
226
- user: Optional[dict] = Depends(get_authenticated_user)
227
- ):
228
- """
229
- Compare outputs from multiple models for the same input.
230
- Returns results from all specified models (or all available if not specified).
231
- """
232
- total_start = time.time()
233
-
234
- # Get models to compare
235
- available_models = registry.get_available_model_names()
236
- models_to_use = request.models if request.models else available_models
237
-
238
- # Validate requested models
239
- for model in models_to_use:
240
- if model not in available_models:
241
- raise HTTPException(status_code=400, detail=f"Unknown model: {model}")
242
-
243
- # Load Domain Configuration
244
- domain_config = get_domain_config(request.domain)
245
- DomainSchema = domain_config["schema"]
246
- create_prompt = domain_config["create_prompt"]
247
-
248
- # Validate Input Data
249
- try:
250
- validated_data = DomainSchema(**request.data)
251
- except ValidationError as e:
252
- raise HTTPException(status_code=422, detail=f"Invalid data: {e}")
253
-
254
- # Prompt Construction
255
- chat_messages = create_prompt(validated_data)
256
-
257
- # Generate with each model
258
- results = []
259
-
260
- async def generate_with_model(model_name: str) -> ModelResult:
261
- start_time = time.time()
262
- try:
263
- llm = await registry.get_model(model_name)
264
- output = await llm.generate(
265
- chat_messages=chat_messages,
266
- max_new_tokens=150,
267
- temperature=0.75,
268
- top_p=0.9,
269
- )
270
- return ModelResult(
271
- model=model_name,
272
- output=output,
273
- time=round(time.time() - start_time, 2),
274
- type=MODEL_CONFIG[model_name]["type"],
275
- error=None
276
- )
277
- except Exception as e:
278
- return ModelResult(
279
- model=model_name,
280
- output="",
281
- time=round(time.time() - start_time, 2),
282
- type=MODEL_CONFIG[model_name]["type"],
283
- error=str(e)
284
- )
285
-
286
- # Run all models (sequentially to avoid memory issues)
287
- for model_name in models_to_use:
288
- result = await generate_with_model(model_name)
289
- results.append(result)
290
-
291
- return CompareResponse(
292
- domain=request.domain,
293
- results=results,
294
- total_time=round(time.time() - total_start, 2)
295
- )
296
-
297
- @app.get("/user/me")
298
- async def get_user_info(user: dict = Depends(get_authenticated_user)):
299
- """Get current authenticated user information"""
300
- if not user:
301
- raise HTTPException(status_code=401, detail="Not authenticated")
302
- return {
303
- "user_id": user['user_id'],
304
- "email": user['email'],
305
- "name": user.get('name', 'Unknown')
306
- }
307
-
308
-
309
- # --- Batch Infill Endpoints ---
310
-
311
- @app.post("/infill", response_model=InfillResponse)
312
- async def batch_infill(
313
- request: InfillRequest,
314
- user: Optional[dict] = Depends(get_authenticated_user)
315
- ):
316
  """
317
- Batch gap-filling for ads using a single model.
318
-
319
- Accepts items with [GAP:n] markers or ___ and returns filled text
320
- with per-gap choices and alternatives.
321
 
322
- NOTE: For texts > 6000 chars, consider chunking (not yet implemented).
323
  """
324
- print(f"DEBUG: Hit batch_infill endpoint with model={request.model}", flush=True)
325
- total_start = time.time()
326
-
327
  # Validate model
328
  if request.model not in registry.get_available_model_names():
329
  raise HTTPException(status_code=400, detail=f"Unknown model: {request.model}")
330
 
331
- # Load domain config for infill prompt
332
- domain_config = get_domain_config(request.domain)
333
- if "create_infill_prompt" not in domain_config:
334
- raise HTTPException(
335
- status_code=400,
336
- detail=f"Domain '{request.domain}' does not support infill operations"
337
- )
338
- create_infill_prompt = domain_config["create_infill_prompt"]
339
-
340
- # Process each item
341
- results = []
342
- error_count = 0
343
-
344
- for item in request.items:
345
- result = await process_infill_item(
346
- item=item,
347
- model_name=request.model,
348
- options=request.options,
349
- create_infill_prompt=create_infill_prompt
350
- )
351
- results.append(result)
352
- if result.status == "error":
353
- error_count += 1
354
-
355
- return InfillResponse(
356
- model=request.model,
357
- results=results,
358
- total_time=round(time.time() - total_start, 2),
359
- processed_count=len(results),
360
- error_count=error_count
361
- )
362
-
363
-
364
- @app.post("/compare-infill", response_model=CompareInfillResponse)
365
- async def compare_infill(
366
- request: CompareInfillRequest,
367
- user: Optional[dict] = Depends(get_authenticated_user)
368
- ):
369
- """
370
- Multi-model batch gap-filling comparison for A/B testing.
371
-
372
- Runs the same batch of items through multiple models and returns
373
- per-model results for comparison.
374
- """
375
- total_start = time.time()
376
-
377
- # Get models to compare
378
- available_models = registry.get_available_model_names()
379
- models_to_use = request.models if request.models else available_models
380
-
381
- # Validate requested models
382
- for model in models_to_use:
383
- if model not in available_models:
384
- raise HTTPException(status_code=400, detail=f"Unknown model: {model}")
385
-
386
- # Load domain config
387
- domain_config = get_domain_config(request.domain)
388
- if "create_infill_prompt" not in domain_config:
389
- raise HTTPException(
390
- status_code=400,
391
- detail=f"Domain '{request.domain}' does not support infill operations"
392
- )
393
- create_infill_prompt = domain_config["create_infill_prompt"]
394
-
395
- # Process with each model (sequentially for memory safety)
396
- model_results = []
397
-
398
- for model_name in models_to_use:
399
- model_start = time.time()
400
- results = []
401
- error_count = 0
402
-
403
- for item in request.items:
404
- result = await process_infill_item(
405
- item=item,
406
- model_name=model_name,
407
- options=request.options,
408
- create_infill_prompt=create_infill_prompt
409
- )
410
- results.append(result)
411
- if result.status == "error":
412
- error_count += 1
413
-
414
- model_results.append(ModelInfillResult(
415
- model=model_name,
416
- type=MODEL_CONFIG[model_name]["type"],
417
- results=results,
418
- time=round(time.time() - model_start, 2),
419
- error_count=error_count
420
- ))
421
-
422
- return CompareInfillResponse(
423
- domain=request.domain,
424
- models=model_results,
425
- total_time=round(time.time() - total_start, 2)
426
- )
427
-
428
-
429
- async def process_infill_item(
430
- item,
431
- model_name: str,
432
- options,
433
- create_infill_prompt
434
- ) -> InfillResult:
435
- """
436
- Process a single infill item.
437
-
438
- Returns InfillResult with status, filled_text, and gaps.
439
- """
440
  try:
441
- # Normalize gaps to [GAP:n] format
442
- normalized_text, gaps = normalize_gaps_to_tagged(item.text_with_gaps)
443
-
444
- if not gaps:
445
- # No gaps found, return original text
446
- return InfillResult(
447
- id=item.id,
448
- status="ok",
449
- filled_text=item.text_with_gaps,
450
- gaps=[],
451
- error=None
452
- )
453
-
454
- # Build prompt
455
- if item.custom_messages:
456
- chat_messages = item.custom_messages
457
- use_grammar = False # Custom messages = plain text output expected
458
- else:
459
- chat_messages = create_infill_prompt(normalized_text, options, attributes=item.attributes)
460
- use_grammar = True # Standard prompt = use grammar for structured JSON
461
-
462
- # Generate with optional GBNF grammar constraint
463
- llm = await registry.get_model(model_name)
464
 
465
- grammar_str = None
466
- if use_grammar and hasattr(llm, 'llm') and llm.llm is not None:
467
- # Use model's default grammar (loaded from answers.gbnf) if available
468
- if hasattr(llm, 'default_grammar') and llm.default_grammar:
469
- grammar_str = llm.default_grammar
470
- print(f"DEBUG: Using model's default GBNF grammar", flush=True)
471
- else:
472
- # Fallback to dynamic grammar generation
473
- try:
474
- from app.logic.grammar_utils import get_infill_grammar
475
- grammar_str = get_infill_grammar(len(gaps))
476
- print(f"DEBUG: Using dynamic GBNF grammar for {len(gaps)} gaps", flush=True)
477
- except ImportError:
478
- pass
479
-
480
- raw_output = await llm.generate(
481
- chat_messages=chat_messages,
482
- max_new_tokens=options.max_new_tokens,
483
- temperature=0.3 if use_grammar else options.temperature, # Lower temp with grammar
484
- top_p=0.9,
485
- grammar=grammar_str,
486
  )
487
 
488
- # If custom_messages were provided, the output is plain text (not JSON)
489
- # Just return it directly as a single gap fill
490
- if item.custom_messages:
491
- # Clean up the raw output - strip whitespace, quotes, etc.
492
- choice = raw_output.strip().strip('"\'.,').strip()
493
- return InfillResult(
494
- id=item.id,
495
- status="ok",
496
- filled_text=choice, # The filled text is just the choice itself
497
- gaps=[GapFill(index=1, marker="[GAP:1]", choice=choice, alternatives=[])],
498
- error=None
499
- )
500
-
501
- # Parse JSON from output (standard prompt format)
502
- parsed = parse_infill_response(raw_output)
503
-
504
- if not parsed:
505
- # JSON parsing failed
506
- return InfillResult(
507
- id=item.id,
508
- status="error",
509
- filled_text=None,
510
- gaps=[],
511
- error=f"Failed to parse JSON from model output: {raw_output[:200]}..."
512
- )
513
-
514
- # Extract gaps and build result
515
- gap_fills = []
516
- fills_dict = {}
517
-
518
- for gap_data in parsed.get("gaps", []):
519
- gap_fill = GapFill(
520
- index=gap_data.get("index", 0),
521
- marker=gap_data.get("marker", ""),
522
- choice=gap_data.get("choice", ""),
523
- alternatives=gap_data.get("alternatives", [])
524
- )
525
- gap_fills.append(gap_fill)
526
- fills_dict[gap_fill.index] = gap_fill.choice
527
-
528
- # Get filled text - prefer model's version, fallback to reconstruction
529
- filled_text = parsed.get("filled_text")
530
- if not filled_text and fills_dict:
531
- filled_text = apply_fills(normalized_text, gaps, fills_dict)
532
-
533
- return InfillResult(
534
- id=item.id,
535
- status="ok",
536
- filled_text=filled_text,
537
- gaps=gap_fills,
538
- error=None
539
  )
540
-
541
  except Exception as e:
542
- return InfillResult(
543
- id=item.id,
544
- status="error",
545
- filled_text=None,
546
- gaps=[],
547
- error=str(e)
548
- )
 
 
 
 
1
  import os
 
 
 
2
  import subprocess
3
  import sys
 
4
  from typing import Optional, List
5
+ from fastapi import FastAPI, HTTPException
6
+ from pydantic import BaseModel
7
 
8
+ # Install llama-cpp-python with CUDA support at runtime
9
+ # Skip during tests/imports
10
+ if os.getenv("SKIP_LLAMA_INSTALL") != "1":
11
+ try:
12
+ import llama_cpp
13
+ except ImportError:
14
+ print("[STARTUP] Installing llama-cpp-python with CUDA...")
15
+ result = subprocess.run(
16
+ [sys.executable, "-m", "pip", "install", "--quiet", "--prefer-binary",
17
+ "--index-url", "https://abetlen.github.io/llama-cpp-python/whl/cu121",
18
+ "llama-cpp-python[server]>=0.3.16"],
19
+ capture_output=True,
20
+ text=True,
21
+ timeout=60
22
+ )
23
+ if result.returncode != 0:
24
+ print("[STARTUP] CUDA wheel failed, trying CPU fallback...")
25
+ subprocess.run([sys.executable, "-m", "pip", "install", "--quiet", "llama-cpp-python>=0.3.16"], check=False, timeout=60)
 
26
 
27
  from app.models.registry import registry, MODEL_CONFIG
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ # Request/Response Models
30
+ class Message(BaseModel):
31
+ role: str
32
+ content: str
33
+
34
+ class ChatRequest(BaseModel):
35
+ model: str
36
+ messages: List[Message]
37
+ max_tokens: int = 150
38
+ temperature: float = 0.7
39
+ top_p: float = 0.9
40
+
41
+ class ChatChoice(BaseModel):
42
+ message: Message
43
+ finish_reason: str
44
+
45
+ class ChatResponse(BaseModel):
46
+ model: str
47
+ choices: List[ChatChoice]
48
+ usage: dict
49
+
50
+ class GenerateRequest(BaseModel):
51
+ model: str
52
+ prompt: str
53
+ max_tokens: int = 150
54
+ temperature: float = 0.7
55
+ top_p: float = 0.9
56
+
57
+ class GenerateResponse(BaseModel):
58
+ model: str
59
+ text: str
60
+ tokens_generated: int
61
+
62
+ class ModelInfo(BaseModel):
63
+ name: str
64
+ type: str
65
+ device: str = "unknown"
66
+
67
+ class ModelsResponse(BaseModel):
68
+ models: List[ModelInfo]
69
+
70
+ class HealthResponse(BaseModel):
71
+ status: str
72
+ gpu_available: bool
73
+ models_available: int
74
+
75
+ # Create app
76
  app = FastAPI(
77
+ title="Bielik LLM Service",
78
+ description="Pure inference service for Bielik models with GPU acceleration",
79
+ version="2.0.0"
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  )
81
 
82
  @app.on_event("startup")
83
  async def startup_event():
84
+ """Initialize service on startup."""
 
 
 
85
  print("Application started. Models will be loaded lazily on first request.")
86
  print(f"Available models: {registry.get_available_model_names()}")
87
 
 
95
  except Exception as e:
96
  print(f"GPU check failed: {e}")
97
 
98
+ @app.get("/health", response_model=HealthResponse)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  async def health_check():
100
+ """Health check endpoint."""
 
 
 
 
101
  gpu_available = False
 
102
  try:
103
  import torch
104
  gpu_available = torch.cuda.is_available()
 
105
  except:
106
  pass
107
 
108
+ return HealthResponse(
109
+ status="ok",
110
+ gpu_available=gpu_available,
111
+ models_available=len(registry.get_available_model_names())
112
+ )
 
 
 
113
 
114
+ @app.get("/models", response_model=ModelsResponse)
115
  async def list_models():
116
+ """List all available models."""
117
+ models_list = []
118
+ for model_name in registry.get_available_model_names():
119
+ info = registry.get_model_info(model_name)
120
+ models_list.append(ModelInfo(
121
+ name=model_name,
122
+ type=info.get("type", "unknown"),
123
+ device=info.get("device", "unknown")
124
+ ))
125
+ return ModelsResponse(models=models_list)
126
 
127
+ @app.post("/chat", response_model=ChatResponse)
128
+ async def chat_completion(request: ChatRequest):
129
  """
130
+ Chat completion endpoint (OpenAI compatible).
 
 
 
 
131
 
132
+ Accepts a list of messages and returns a completion.
 
 
 
 
 
 
 
133
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  # Validate model
135
+ if request.model not in registry.get_available_model_names():
136
+ raise HTTPException(status_code=400, detail=f"Unknown model: {request.model}")
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  try:
139
+ # Load model
140
+ llm = await registry.get_model(request.model)
141
+
142
+ # Convert messages to list of dicts
143
+ messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
144
+
145
+ # Generate
146
+ output = await llm.generate(
147
+ chat_messages=messages,
148
+ max_new_tokens=request.max_tokens,
149
+ temperature=request.temperature,
150
+ top_p=request.top_p,
151
+ )
152
+
153
+ return ChatResponse(
154
+ model=request.model,
155
+ choices=[ChatChoice(
156
+ message=Message(role="assistant", content=output),
157
+ finish_reason="stop"
158
+ )],
159
+ usage={
160
+ "prompt_tokens": sum(len(msg.get("content", "").split()) for msg in messages),
161
+ "completion_tokens": len(output.split())
162
+ }
163
  )
164
  except Exception as e:
 
165
  raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
166
 
167
+ @app.post("/generate", response_model=GenerateResponse)
168
+ async def generate_text(request: GenerateRequest):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  """
170
+ Raw text generation endpoint.
 
 
 
171
 
172
+ Accepts a prompt string and returns generated text.
173
  """
 
 
 
174
  # Validate model
175
  if request.model not in registry.get_available_model_names():
176
  raise HTTPException(status_code=400, detail=f"Unknown model: {request.model}")
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  try:
179
+ # Load model
180
+ llm = await registry.get_model(request.model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
+ # Generate
183
+ output = await llm.generate(
184
+ prompt=request.prompt,
185
+ max_new_tokens=request.max_tokens,
186
+ temperature=request.temperature,
187
+ top_p=request.top_p,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  )
189
 
190
+ return GenerateResponse(
191
+ model=request.model,
192
+ text=output,
193
+ tokens_generated=len(output.split())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  )
 
195
  except Exception as e:
196
+ raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
197
+
198
+ @app.get("/")
199
+ async def root():
200
+ """Root endpoint."""
201
+ return {
202
+ "message": "Bielik LLM Service",
203
+ "docs": "/docs",
204
+ "endpoints": ["/chat", "/generate", "/models", "/health"]
205
+ }
app/main_backup.py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import asyncio
4
+ import importlib
5
+ import subprocess
6
+ import sys
7
+ from fastapi import FastAPI, HTTPException, Depends, Body
8
+ from typing import Optional, List
9
+ from pydantic import ValidationError
10
+
11
+ # llama-cpp-python installed at runtime with CUDA support
12
+ try:
13
+ import llama_cpp
14
+ except ImportError:
15
+ print("[STARTUP] Installing llama-cpp-python with CUDA...")
16
+ env = os.environ.copy()
17
+ result = subprocess.run(
18
+ [sys.executable, "-m", "pip", "install", "--quiet", "--prefer-binary",
19
+ "--index-url", "https://abetlen.github.io/llama-cpp-python/whl/cu121",
20
+ "llama-cpp-python[server]>=0.3.16"],
21
+ capture_output=True,
22
+ text=True
23
+ )
24
+ if result.returncode != 0:
25
+ print("[STARTUP] CUDA wheel failed, trying CPU fallback...")
26
+ print(f"[STARTUP] Error details: {result.stderr[:500]}")
27
+ subprocess.run([sys.executable, "-m", "pip", "install", "--quiet", "llama-cpp-python>=0.3.16"], check=False)
28
+ else:
29
+ print("[STARTUP] llama-cpp-python with CUDA installed")
30
+
31
+ from app.models.registry import registry, MODEL_CONFIG
32
+ from fastapi.middleware.cors import CORSMiddleware
33
+ from app.schemas.schemas import (
34
+ EnhancedDescriptionResponse,
35
+ CompareRequest,
36
+ CompareResponse,
37
+ ModelResult,
38
+ ModelInfo,
39
+ InfillRequest,
40
+ InfillResponse,
41
+ InfillResult,
42
+ GapFill,
43
+ CompareInfillRequest,
44
+ CompareInfillResponse,
45
+ ModelInfillResult,
46
+ )
47
+ from app.logic.infill_utils import (
48
+ detect_gaps,
49
+ parse_infill_response,
50
+ apply_fills,
51
+ build_fills_dict,
52
+ normalize_gaps_to_tagged,
53
+ )
54
+ from app.auth.placeholder_auth import get_authenticated_user
55
+
56
+ app = FastAPI(
57
+ title="Multi-Model Description Enhancer",
58
+ description="AI-powered service for enhancing descriptions using multiple LLMs for A/B testing",
59
+ version="3.0.0"
60
+ )
61
+
62
+ # CORS configuration
63
+ app.add_middleware(
64
+ CORSMiddleware,
65
+ allow_origins=[
66
+ "http://localhost:5173",
67
+ "http://localhost:5174",
68
+ os.getenv("FRONTEND_URL", "http://localhost:5173")
69
+ ],
70
+ allow_credentials=True,
71
+ allow_methods=["POST", "GET"],
72
+ allow_headers=["*"],
73
+ )
74
+
75
+ @app.on_event("startup")
76
+ async def startup_event():
77
+ """
78
+ Startup event - models are loaded lazily on first request.
79
+ No models are pre-loaded to conserve memory.
80
+ """
81
+ print("Application started. Models will be loaded lazily on first request.")
82
+ print(f"Available models: {registry.get_available_model_names()}")
83
+
84
+ try:
85
+ import torch
86
+ gpu_available = torch.cuda.is_available()
87
+ gpu_name = torch.cuda.get_device_name(0) if gpu_available else "N/A"
88
+ print(f"GPU available: {gpu_available}, Device: {gpu_name}")
89
+ except ImportError:
90
+ print("PyTorch not available for GPU check")
91
+ except Exception as e:
92
+ print(f"GPU check failed: {e}")
93
+
94
+ # --- Helper function to load domain logic ---
95
+ def get_domain_config(domain: str):
96
+ try:
97
+ module = importlib.import_module(f"app.domains.{domain}.config")
98
+ return module.domain_config
99
+ except (ImportError, AttributeError):
100
+ raise HTTPException(status_code=404, detail=f"Domain '{domain}' not found or not configured correctly.")
101
+
102
+ # --- API Endpoints ---
103
+
104
+ @app.get("/")
105
+ async def read_root():
106
+ return {"message": "Welcome to the Multi-Model Description Enhancer API! Go to /docs for documentation."}
107
+
108
+ @app.get("/health")
109
+ async def health_check():
110
+ """Check API health and model status."""
111
+ models = registry.list_models()
112
+ loaded_models = registry.get_loaded_models()
113
+ active_model = registry.get_active_model()
114
+
115
+ gpu_available = False
116
+ gpu_name = "N/A"
117
+ try:
118
+ import torch
119
+ gpu_available = torch.cuda.is_available()
120
+ gpu_name = torch.cuda.get_device_name(0) if gpu_available else "N/A"
121
+ except:
122
+ pass
123
+
124
+ return {
125
+ "status": "ok",
126
+ "available_models": len(models),
127
+ "loaded_models": loaded_models,
128
+ "active_local_model": active_model,
129
+ "gpu_available": gpu_available,
130
+ "gpu_device": gpu_name,
131
+ }
132
+
133
+ @app.get("/models", response_model=List[ModelInfo])
134
+ async def list_models():
135
+ """List all available models with their load status."""
136
+ return registry.list_models()
137
+
138
+ @app.post("/models/{model_name}/load")
139
+ async def load_model(model_name: str):
140
+ """
141
+ Explicitly load a model into memory.
142
+ For local models: unloads any previously loaded local model first.
143
+ """
144
+ if model_name not in registry.get_available_model_names():
145
+ raise HTTPException(status_code=404, detail=f"Unknown model: {model_name}")
146
+
147
+ try:
148
+ info = await registry.load_model(model_name)
149
+ return {"status": "loaded", "model": info}
150
+ except Exception as e:
151
+ raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
152
+
153
+ @app.post("/models/{model_name}/unload")
154
+ async def unload_model(model_name: str):
155
+ """
156
+ Explicitly unload a model from memory to free resources.
157
+ """
158
+ if model_name not in registry.get_available_model_names():
159
+ raise HTTPException(status_code=404, detail=f"Unknown model: {model_name}")
160
+
161
+ try:
162
+ result = await registry.unload_model(model_name)
163
+ return result
164
+ except Exception as e:
165
+ raise HTTPException(status_code=500, detail=f"Failed to unload model: {str(e)}")
166
+
167
+ @app.post("/enhance-description", response_model=EnhancedDescriptionResponse)
168
+ async def enhance_description(
169
+ domain: str = Body(..., embed=True),
170
+ data: dict = Body(..., embed=True),
171
+ model: str = Body("bielik-1.5b", embed=True),
172
+ user: Optional[dict] = Depends(get_authenticated_user)
173
+ ):
174
+ """
175
+ Generate an enhanced description using a single model.
176
+ - **domain**: The name of the domain (e.g., 'cars').
177
+ - **data**: A dictionary with the data for the description.
178
+ - **model**: Model to use (default: bielik-1.5b)
179
+ """
180
+ start_time = time.time()
181
+
182
+ # Validate model
183
+ if model not in registry.get_available_model_names():
184
+ raise HTTPException(status_code=400, detail=f"Unknown model: {model}")
185
+
186
+ # Load Domain Configuration
187
+ domain_config = get_domain_config(domain)
188
+ DomainSchema = domain_config["schema"]
189
+ create_prompt = domain_config["create_prompt"]
190
+
191
+ # Validate Input Data
192
+ try:
193
+ validated_data = DomainSchema(**data)
194
+ except ValidationError as e:
195
+ raise HTTPException(status_code=422, detail=f"Invalid data for domain '{domain}': {e}")
196
+
197
+ # Prompt Construction
198
+ chat_messages = create_prompt(validated_data)
199
+
200
+ # Text Generation
201
+ try:
202
+ llm = await registry.get_model(model)
203
+ generated_description = await llm.generate(
204
+ chat_messages=chat_messages,
205
+ max_new_tokens=150,
206
+ temperature=0.75,
207
+ top_p=0.9,
208
+ )
209
+ except Exception as e:
210
+ print(f"Error during text generation with {model}: {e}")
211
+ raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
212
+
213
+ generation_time = time.time() - start_time
214
+ user_email = user['email'] if user else "anonymous"
215
+
216
+ return EnhancedDescriptionResponse(
217
+ description=generated_description,
218
+ model_used=MODEL_CONFIG[model]["id"],
219
+ generation_time=round(generation_time, 2),
220
+ user_email=user_email
221
+ )
222
+
223
+ @app.post("/compare", response_model=CompareResponse)
224
+ async def compare_models(
225
+ request: CompareRequest,
226
+ user: Optional[dict] = Depends(get_authenticated_user)
227
+ ):
228
+ """
229
+ Compare outputs from multiple models for the same input.
230
+ Returns results from all specified models (or all available if not specified).
231
+ """
232
+ total_start = time.time()
233
+
234
+ # Get models to compare
235
+ available_models = registry.get_available_model_names()
236
+ models_to_use = request.models if request.models else available_models
237
+
238
+ # Validate requested models
239
+ for model in models_to_use:
240
+ if model not in available_models:
241
+ raise HTTPException(status_code=400, detail=f"Unknown model: {model}")
242
+
243
+ # Load Domain Configuration
244
+ domain_config = get_domain_config(request.domain)
245
+ DomainSchema = domain_config["schema"]
246
+ create_prompt = domain_config["create_prompt"]
247
+
248
+ # Validate Input Data
249
+ try:
250
+ validated_data = DomainSchema(**request.data)
251
+ except ValidationError as e:
252
+ raise HTTPException(status_code=422, detail=f"Invalid data: {e}")
253
+
254
+ # Prompt Construction
255
+ chat_messages = create_prompt(validated_data)
256
+
257
+ # Generate with each model
258
+ results = []
259
+
260
+ async def generate_with_model(model_name: str) -> ModelResult:
261
+ start_time = time.time()
262
+ try:
263
+ llm = await registry.get_model(model_name)
264
+ output = await llm.generate(
265
+ chat_messages=chat_messages,
266
+ max_new_tokens=150,
267
+ temperature=0.75,
268
+ top_p=0.9,
269
+ )
270
+ return ModelResult(
271
+ model=model_name,
272
+ output=output,
273
+ time=round(time.time() - start_time, 2),
274
+ type=MODEL_CONFIG[model_name]["type"],
275
+ error=None
276
+ )
277
+ except Exception as e:
278
+ return ModelResult(
279
+ model=model_name,
280
+ output="",
281
+ time=round(time.time() - start_time, 2),
282
+ type=MODEL_CONFIG[model_name]["type"],
283
+ error=str(e)
284
+ )
285
+
286
+ # Run all models (sequentially to avoid memory issues)
287
+ for model_name in models_to_use:
288
+ result = await generate_with_model(model_name)
289
+ results.append(result)
290
+
291
+ return CompareResponse(
292
+ domain=request.domain,
293
+ results=results,
294
+ total_time=round(time.time() - total_start, 2)
295
+ )
296
+
297
+ @app.get("/user/me")
298
+ async def get_user_info(user: dict = Depends(get_authenticated_user)):
299
+ """Get current authenticated user information"""
300
+ if not user:
301
+ raise HTTPException(status_code=401, detail="Not authenticated")
302
+ return {
303
+ "user_id": user['user_id'],
304
+ "email": user['email'],
305
+ "name": user.get('name', 'Unknown')
306
+ }
307
+
308
+
309
+ # --- Batch Infill Endpoints ---
310
+
311
+ @app.post("/infill", response_model=InfillResponse)
312
+ async def batch_infill(
313
+ request: InfillRequest,
314
+ user: Optional[dict] = Depends(get_authenticated_user)
315
+ ):
316
+ """
317
+ Batch gap-filling for ads using a single model.
318
+
319
+ Accepts items with [GAP:n] markers or ___ and returns filled text
320
+ with per-gap choices and alternatives.
321
+
322
+ NOTE: For texts > 6000 chars, consider chunking (not yet implemented).
323
+ """
324
+ print(f"DEBUG: Hit batch_infill endpoint with model={request.model}", flush=True)
325
+ total_start = time.time()
326
+
327
+ # Validate model
328
+ if request.model not in registry.get_available_model_names():
329
+ raise HTTPException(status_code=400, detail=f"Unknown model: {request.model}")
330
+
331
+ # Load domain config for infill prompt
332
+ domain_config = get_domain_config(request.domain)
333
+ if "create_infill_prompt" not in domain_config:
334
+ raise HTTPException(
335
+ status_code=400,
336
+ detail=f"Domain '{request.domain}' does not support infill operations"
337
+ )
338
+ create_infill_prompt = domain_config["create_infill_prompt"]
339
+
340
+ # Process each item
341
+ results = []
342
+ error_count = 0
343
+
344
+ for item in request.items:
345
+ result = await process_infill_item(
346
+ item=item,
347
+ model_name=request.model,
348
+ options=request.options,
349
+ create_infill_prompt=create_infill_prompt
350
+ )
351
+ results.append(result)
352
+ if result.status == "error":
353
+ error_count += 1
354
+
355
+ return InfillResponse(
356
+ model=request.model,
357
+ results=results,
358
+ total_time=round(time.time() - total_start, 2),
359
+ processed_count=len(results),
360
+ error_count=error_count
361
+ )
362
+
363
+
364
+ @app.post("/compare-infill", response_model=CompareInfillResponse)
365
+ async def compare_infill(
366
+ request: CompareInfillRequest,
367
+ user: Optional[dict] = Depends(get_authenticated_user)
368
+ ):
369
+ """
370
+ Multi-model batch gap-filling comparison for A/B testing.
371
+
372
+ Runs the same batch of items through multiple models and returns
373
+ per-model results for comparison.
374
+ """
375
+ total_start = time.time()
376
+
377
+ # Get models to compare
378
+ available_models = registry.get_available_model_names()
379
+ models_to_use = request.models if request.models else available_models
380
+
381
+ # Validate requested models
382
+ for model in models_to_use:
383
+ if model not in available_models:
384
+ raise HTTPException(status_code=400, detail=f"Unknown model: {model}")
385
+
386
+ # Load domain config
387
+ domain_config = get_domain_config(request.domain)
388
+ if "create_infill_prompt" not in domain_config:
389
+ raise HTTPException(
390
+ status_code=400,
391
+ detail=f"Domain '{request.domain}' does not support infill operations"
392
+ )
393
+ create_infill_prompt = domain_config["create_infill_prompt"]
394
+
395
+ # Process with each model (sequentially for memory safety)
396
+ model_results = []
397
+
398
+ for model_name in models_to_use:
399
+ model_start = time.time()
400
+ results = []
401
+ error_count = 0
402
+
403
+ for item in request.items:
404
+ result = await process_infill_item(
405
+ item=item,
406
+ model_name=model_name,
407
+ options=request.options,
408
+ create_infill_prompt=create_infill_prompt
409
+ )
410
+ results.append(result)
411
+ if result.status == "error":
412
+ error_count += 1
413
+
414
+ model_results.append(ModelInfillResult(
415
+ model=model_name,
416
+ type=MODEL_CONFIG[model_name]["type"],
417
+ results=results,
418
+ time=round(time.time() - model_start, 2),
419
+ error_count=error_count
420
+ ))
421
+
422
+ return CompareInfillResponse(
423
+ domain=request.domain,
424
+ models=model_results,
425
+ total_time=round(time.time() - total_start, 2)
426
+ )
427
+
428
+
429
+ async def process_infill_item(
430
+ item,
431
+ model_name: str,
432
+ options,
433
+ create_infill_prompt
434
+ ) -> InfillResult:
435
+ """
436
+ Process a single infill item.
437
+
438
+ Returns InfillResult with status, filled_text, and gaps.
439
+ """
440
+ try:
441
+ # Normalize gaps to [GAP:n] format
442
+ normalized_text, gaps = normalize_gaps_to_tagged(item.text_with_gaps)
443
+
444
+ if not gaps:
445
+ # No gaps found, return original text
446
+ return InfillResult(
447
+ id=item.id,
448
+ status="ok",
449
+ filled_text=item.text_with_gaps,
450
+ gaps=[],
451
+ error=None
452
+ )
453
+
454
+ # Build prompt
455
+ if item.custom_messages:
456
+ chat_messages = item.custom_messages
457
+ use_grammar = False # Custom messages = plain text output expected
458
+ else:
459
+ chat_messages = create_infill_prompt(normalized_text, options, attributes=item.attributes)
460
+ use_grammar = True # Standard prompt = use grammar for structured JSON
461
+
462
+ # Generate with optional GBNF grammar constraint
463
+ llm = await registry.get_model(model_name)
464
+
465
+ grammar_str = None
466
+ if use_grammar and hasattr(llm, 'llm') and llm.llm is not None:
467
+ # Use model's default grammar (loaded from answers.gbnf) if available
468
+ if hasattr(llm, 'default_grammar') and llm.default_grammar:
469
+ grammar_str = llm.default_grammar
470
+ print(f"DEBUG: Using model's default GBNF grammar", flush=True)
471
+ else:
472
+ # Fallback to dynamic grammar generation
473
+ try:
474
+ from app.logic.grammar_utils import get_infill_grammar
475
+ grammar_str = get_infill_grammar(len(gaps))
476
+ print(f"DEBUG: Using dynamic GBNF grammar for {len(gaps)} gaps", flush=True)
477
+ except ImportError:
478
+ pass
479
+
480
+ raw_output = await llm.generate(
481
+ chat_messages=chat_messages,
482
+ max_new_tokens=options.max_new_tokens,
483
+ temperature=0.3 if use_grammar else options.temperature, # Lower temp with grammar
484
+ top_p=0.9,
485
+ grammar=grammar_str,
486
+ )
487
+
488
+ # If custom_messages were provided, the output is plain text (not JSON)
489
+ # Just return it directly as a single gap fill
490
+ if item.custom_messages:
491
+ # Clean up the raw output - strip whitespace, quotes, etc.
492
+ choice = raw_output.strip().strip('"\'.,').strip()
493
+ return InfillResult(
494
+ id=item.id,
495
+ status="ok",
496
+ filled_text=choice, # The filled text is just the choice itself
497
+ gaps=[GapFill(index=1, marker="[GAP:1]", choice=choice, alternatives=[])],
498
+ error=None
499
+ )
500
+
501
+ # Parse JSON from output (standard prompt format)
502
+ parsed = parse_infill_response(raw_output)
503
+
504
+ if not parsed:
505
+ # JSON parsing failed
506
+ return InfillResult(
507
+ id=item.id,
508
+ status="error",
509
+ filled_text=None,
510
+ gaps=[],
511
+ error=f"Failed to parse JSON from model output: {raw_output[:200]}..."
512
+ )
513
+
514
+ # Extract gaps and build result
515
+ gap_fills = []
516
+ fills_dict = {}
517
+
518
+ for gap_data in parsed.get("gaps", []):
519
+ gap_fill = GapFill(
520
+ index=gap_data.get("index", 0),
521
+ marker=gap_data.get("marker", ""),
522
+ choice=gap_data.get("choice", ""),
523
+ alternatives=gap_data.get("alternatives", [])
524
+ )
525
+ gap_fills.append(gap_fill)
526
+ fills_dict[gap_fill.index] = gap_fill.choice
527
+
528
+ # Get filled text - prefer model's version, fallback to reconstruction
529
+ filled_text = parsed.get("filled_text")
530
+ if not filled_text and fills_dict:
531
+ filled_text = apply_fills(normalized_text, gaps, fills_dict)
532
+
533
+ return InfillResult(
534
+ id=item.id,
535
+ status="ok",
536
+ filled_text=filled_text,
537
+ gaps=gap_fills,
538
+ error=None
539
+ )
540
+
541
+ except Exception as e:
542
+ return InfillResult(
543
+ id=item.id,
544
+ status="error",
545
+ filled_text=None,
546
+ gaps=[],
547
+ error=str(e)
548
+ )
app/main_simple.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import sys
4
+ from typing import Optional, List
5
+ from fastapi import FastAPI, HTTPException
6
+ from pydantic import BaseModel
7
+
8
# Install llama-cpp-python with CUDA support at runtime.
# Self-healing import: on hosted runtimes the dependency may be missing.
try:
    import llama_cpp
except ImportError:
    print("[STARTUP] Installing llama-cpp-python with CUDA...")
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "--prefer-binary",
         "--index-url", "https://abetlen.github.io/llama-cpp-python/whl/cu121",
         "llama-cpp-python[server]>=0.3.16"],
        capture_output=True,
        text=True
    )
    if result.returncode != 0:
        print("[STARTUP] CUDA wheel failed, trying CPU fallback...")
        # Surface the captured stderr (as app/main.py does) so install
        # failures are debuggable instead of silently discarded.
        print(f"[STARTUP] Error details: {result.stderr[:500]}")
        subprocess.run([sys.executable, "-m", "pip", "install", "--quiet", "llama-cpp-python>=0.3.16"], check=False)
23
+
24
+ from app.models.registry import registry, MODEL_CONFIG
25
+
26
+ # Request/Response Models
27
class Message(BaseModel):
    """A single chat message (OpenAI-style role/content pair)."""
    role: str     # e.g. "user" or "assistant"; not validated here
    content: str  # message text
30
+
31
class ChatRequest(BaseModel):
    """Request body for POST /chat."""
    model: str               # registry model name
    messages: List[Message]  # conversation history, oldest first
    max_tokens: int = 150    # generation length cap
    temperature: float = 0.7
    top_p: float = 0.9       # nucleus sampling threshold
37
+
38
class ChatChoice(BaseModel):
    """One completion choice in a ChatResponse (OpenAI-style)."""
    message: Message    # the generated assistant message
    finish_reason: str  # the service currently always sets "stop"
41
+
42
class ChatResponse(BaseModel):
    """Response body for POST /chat."""
    model: str
    choices: List[ChatChoice]
    usage: dict  # approximate token counts (whitespace-split, not tokenizer-based)
46
+
47
class GenerateRequest(BaseModel):
    """Request body for POST /generate (raw prompt completion)."""
    model: str             # registry model name
    prompt: str            # raw prompt text, no chat template applied
    max_tokens: int = 150  # generation length cap
    temperature: float = 0.7
    top_p: float = 0.9
53
+
54
class GenerateResponse(BaseModel):
    """Response body for POST /generate."""
    model: str
    text: str              # generated completion
    tokens_generated: int  # approximate (whitespace-split) token count
58
+
59
class ModelInfo(BaseModel):
    """Summary of one configured model for GET /models."""
    name: str
    type: str               # backend type as reported by the registry
    device: str = "unknown" # e.g. "cuda:0"; "unknown" when not reported
63
+
64
class ModelsResponse(BaseModel):
    """Response body for GET /models."""
    models: List[ModelInfo]
66
+
67
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    status: str           # always "ok" when the service responds
    gpu_available: bool   # best-effort torch.cuda probe
    models_available: int # count of configured (not necessarily loaded) models
71
+
72
# Create app — a minimal, inference-only service: no auth, domain logic
# or infill endpoints (contrast with app/main_backup.py).
app = FastAPI(
    title="Bielik LLM Service",
    description="Pure inference service for Bielik models with GPU acceleration",
    version="2.0.0"
)
78
+
79
@app.on_event("startup")
async def startup_event():
    """Initialize service on startup."""
    # Models are loaded lazily by the registry; startup only logs status.
    print("Application started. Models will be loaded lazily on first request.")
    print(f"Available models: {registry.get_available_model_names()}")

    try:
        import torch
        has_gpu = torch.cuda.is_available()
        device = torch.cuda.get_device_name(0) if has_gpu else "N/A"
        print(f"GPU available: {has_gpu}, Device: {device}")
    except ImportError:
        print("PyTorch not available for GPU check")
    except Exception as e:
        print(f"GPU check failed: {e}")
+
95
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    GPU detection is best-effort: torch may be missing or CUDA
    initialization may fail; neither should break the health probe.
    """
    gpu_available = False
    try:
        import torch
        gpu_available = torch.cuda.is_available()
    except Exception:
        # Was a bare `except:`; narrowed so SystemExit/KeyboardInterrupt
        # are not swallowed while keeping the best-effort behavior.
        pass

    return HealthResponse(
        status="ok",
        gpu_available=gpu_available,
        models_available=len(registry.get_available_model_names())
    )
110
+
111
@app.get("/models", response_model=ModelsResponse)
async def list_models():
    """List all available models."""
    entries = []
    for name in registry.get_available_model_names():
        meta = registry.get_model_info(name)
        entries.append(ModelInfo(
            name=name,
            type=meta.get("type", "unknown"),
            device=meta.get("device", "unknown"),
        ))
    return ModelsResponse(models=entries)
123
+
124
@app.post("/chat", response_model=ChatResponse)
async def chat_completion(request: ChatRequest):
    """
    Chat completion endpoint (OpenAI compatible).

    Accepts a list of messages and returns a completion.
    Token counts in `usage` are whitespace-split approximations,
    not tokenizer-accurate counts.
    """
    # Validate model
    if request.model not in registry.get_available_model_names():
        raise HTTPException(status_code=400, detail=f"Unknown model: {request.model}")

    try:
        # Load model (lazily, via the registry)
        llm = await registry.get_model(request.model)

        # Convert pydantic messages to plain dicts for the backend
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        output = await llm.generate(
            chat_messages=messages,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
        )

        prompt_tokens = sum(len(msg.get("content", "").split()) for msg in messages)
        completion_tokens = len(output.split())
        return ChatResponse(
            model=request.model,
            choices=[ChatChoice(
                message=Message(role="assistant", content=output),
                finish_reason="stop"
            )],
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                # OpenAI-compatible clients expect total_tokens as well.
                "total_tokens": prompt_tokens + completion_tokens,
            }
        )
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e
163
+
164
@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """
    Raw text generation endpoint.

    Accepts a prompt string and returns generated text.
    `tokens_generated` is a whitespace-split approximation.
    """
    # Validate model
    if request.model not in registry.get_available_model_names():
        raise HTTPException(status_code=400, detail=f"Unknown model: {request.model}")

    try:
        # Load model (lazily, via the registry)
        llm = await registry.get_model(request.model)

        output = await llm.generate(
            prompt=request.prompt,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
        )

        return GenerateResponse(
            model=request.model,
            text=output,
            tokens_generated=len(output.split())
        )
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e
194
+
195
@app.get("/")
async def root():
    """Root endpoint."""
    available_endpoints = ["/chat", "/generate", "/models", "/health"]
    return {
        "message": "Bielik LLM Service",
        "docs": "/docs",
        "endpoints": available_endpoints,
    }
test_simplified.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unit tests for simplified Bielik service
3
+ Tests the API structure without running actual models
4
+ """
5
+ import os
6
+ import json
7
+ from unittest.mock import Mock, AsyncMock, patch
8
+
9
# Skip llama-cpp installation during testing
# NOTE(review): app/main.py does not visibly read SKIP_LLAMA_INSTALL —
# confirm the flag is honored, otherwise importing main may still run pip.
os.environ["SKIP_LLAMA_INSTALL"] = "1"

# Mock the registry before importing main so no model code ever executes.
mock_registry = Mock()
mock_registry.get_available_model_names.return_value = ["bielik-1.5b-transformer", "bielik-11b-transformer"]
mock_registry.get_model_info.return_value = {"type": "transformers", "device": "cuda:0"}
16
+
17
@patch("app.main.registry", mock_registry)
def test_app_structure():
    """Test that simplified app has correct endpoints"""
    from app.main import app

    # Map every registered route path to its allowed HTTP methods.
    routes = {r.path: r.methods for r in app.routes}

    # Required endpoints must be present...
    assert "/" in routes, "Root endpoint missing"
    assert "/health" in routes, "Health endpoint missing"
    assert "/models" in routes, "Models endpoint missing"
    assert "/chat" in routes, "Chat endpoint missing"
    assert "/generate" in routes, "Generate endpoint missing"

    # ...and expose the expected HTTP verbs.
    assert "GET" in routes["/health"], "Health should be GET"
    assert "GET" in routes["/models"], "Models should be GET"
    assert "POST" in routes["/chat"], "Chat should be POST"
    assert "POST" in routes["/generate"], "Generate should be POST"

    print("βœ… App structure correct")
    print(f" Routes: {list(routes.keys())}")
40
+
41
@patch("app.main.registry", mock_registry)
def test_no_business_logic():
    """Verify no domain/infill endpoints exist"""
    from app.main import app

    routes = {route.path for route in app.routes}

    # These should NOT exist.
    # "/enhance-description" added: the legacy API registered that exact
    # path, so checking "/enhance" alone would never catch a regression.
    forbidden_routes = ["/enhance", "/enhance-description", "/compare", "/infill", "/compare-infill", "/user/me"]

    for route in forbidden_routes:
        assert route not in routes, f"Business logic endpoint {route} should not exist"

    print("βœ… No business logic endpoints found")
55
+
56
@patch("app.main.registry", mock_registry)
def test_request_schemas():
    """Test request/response schemas are valid"""
    from app.main import ChatRequest, GenerateRequest, ChatResponse, GenerateResponse
    from app.main import Message, HealthResponse, ModelsResponse

    # ChatRequest constructs and exposes its fields
    chat_request = ChatRequest(
        model="bielik-1.5b-transformer",
        messages=[Message(role="user", content="Hello")],
    )
    assert chat_request.model == "bielik-1.5b-transformer"
    assert len(chat_request.messages) == 1
    print("βœ… ChatRequest schema valid")

    # GenerateRequest constructs and exposes its fields
    generate_request = GenerateRequest(
        model="bielik-1.5b-transformer",
        prompt="Hello world",
    )
    assert generate_request.model == "bielik-1.5b-transformer"
    assert generate_request.prompt == "Hello world"
    print("βœ… GenerateRequest schema valid")

    # HealthResponse constructs and exposes its fields
    health_response = HealthResponse(
        status="ok",
        gpu_available=True,
        models_available=2,
    )
    assert health_response.status == "ok"
    print("βœ… HealthResponse schema valid")

    # ModelsResponse accepts an empty model list
    models_response = ModelsResponse(models=[])
    assert isinstance(models_response.models, list)
    print("βœ… ModelsResponse schema valid")
94
@patch("app.main.registry", mock_registry)
def test_default_values():
    """Test that requests have sensible defaults"""
    from app.main import ChatRequest, GenerateRequest, Message

    # Chat request built with only the required fields
    minimal_chat = ChatRequest(
        model="test",
        messages=[Message(role="user", content="test")],
    )
    assert minimal_chat.max_tokens == 150
    assert minimal_chat.temperature == 0.7
    assert minimal_chat.top_p == 0.9
    print("βœ… Chat defaults correct")

    # Generate request built with only the required fields
    minimal_generate = GenerateRequest(model="test", prompt="test")
    assert minimal_generate.max_tokens == 150
    assert minimal_generate.temperature == 0.7
    assert minimal_generate.top_p == 0.9
    print("βœ… Generate defaults correct")
118
+
119
if __name__ == "__main__":
    print("\n=== Testing Simplified Bielik Service ===\n")

    try:
        test_app_structure()
        test_no_business_logic()
        test_request_schemas()
        test_default_values()

        print("\nβœ… All tests passed!")
        print("\n=== Phase 1 Verification Complete ===")
    except AssertionError as e:
        print(f"\n❌ Test failed: {e}")
        # `exit()` comes from the `site` module and is absent under
        # `python -S`; SystemExit is the portable way to set the status.
        raise SystemExit(1)