| | |
| | from fastapi import FastAPI |
| | from pydantic import BaseModel |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
| | import torch |
| | import os |
| |
|
| | |
| | |
# Hugging Face Hub identifier of the chat model served by this API.
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Fallback device for the unquantized float32 load path below.
device = torch.device("cpu")
| |
|
| | |
| | |
| | |
# bitsandbytes 4-bit NF4 quantization settings: shrinks the model's memory
# footprint, runs compute in bfloat16, and double-quantizes the quantization
# constants for additional savings. Only used on the first load attempt below.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
| |
|
| | |
# Load the tokenizer once, outside the quantization try/except. Previously it
# was loaded inside the `try`, so a tokenizer download failure was misreported
# as a quantization failure, and the fallback path redundantly reloaded it.
tokenizer = AutoTokenizer.from_pretrained(model_id)

try:
    # First attempt: 4-bit quantized load (requires a working bitsandbytes setup).
    print("Tentative de chargement avec quantisation 4-bit...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True
    )
    print(f"Modèle {model_id} chargé et quantifié.")
except Exception as e_quant:
    # Quantized load failed (e.g. bitsandbytes unavailable on a CPU-only host);
    # fall back to an unquantized float32 load pinned to the CPU.
    print(f"Échec de la quantisation : {e_quant}. Tentative de chargement float32 CPU (Attention: peut causer OOM).")
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            trust_remote_code=True
        ).to(device)
        print(f"Modèle {model_id} chargé sur CPU (Float32).")
    except Exception as e_cpu:
        # Without a model the service cannot do anything useful: abort startup.
        print(f"Échec critique du chargement CPU : {e_cpu}")
        raise e_cpu

# Inference-only service: disable dropout / training-mode behavior.
model.eval()
| |
|
# FastAPI application metadata. FIX: the title previously advertised
# "Phi-3 Mini", but the model actually served is TinyLlama-1.1B-Chat
# (see `model_id` at the top of the file).
app = FastAPI(
    title="NLP Space - TinyLlama 1.1B Chat API (CPU)",
    description="API REST pour génération, résumé et classification de texte, optimisée pour CPU."
)
| |
|
| | |
class PromptRequest(BaseModel):
    """Request body shared by all three endpoints: input text plus sampling options."""

    # User-supplied text: free prompt (/generate), text to summarize (/summarize),
    # or text to classify (/classify).
    prompt: str
    # Upper bound on the number of newly generated tokens.
    max_tokens: int = 500
    # Sampling temperature; /summarize and /classify override it with fixed values.
    temperature: float = 0.7
| |
|
| | |
def generate_text_from_model(system_prompt: str, user_prompt: str, max_tokens: int, temperature: float) -> str:
    """Run one chat turn through the model and return only the assistant's reply.

    Args:
        system_prompt: System instruction prepended to the conversation.
        user_prompt: The user's message.
        max_tokens: Maximum number of NEW tokens to generate.
        temperature: Sampling temperature; values <= 0 fall back to greedy decoding
            (previously they would have crashed `generate` with `do_sample=True`).

    Returns:
        The generated reply, with the echoed prompt removed and whitespace stripped.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # FIX: add_generation_prompt must be True so the chat template appends the
    # assistant turn marker (e.g. "<|assistant|>"). With False, the model is
    # left mid-conversation and continues the user turn instead of answering it.
    text_to_generate = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # With device_map="auto" the model may report a 'meta' device; fall back to CPU.
    real_device = model.device if model.device.type != 'meta' else torch.device("cpu")
    inputs = tokenizer(text_to_generate, return_tensors="pt").to(real_device)

    # NOTE(review): use_cache=False makes generation much slower; presumably
    # disabled to save memory — confirm before enabling the KV cache.
    gen_kwargs = {
        "max_new_tokens": max_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "use_cache": False,
    }
    # temperature <= 0 is invalid for sampling; degrade gracefully to greedy decoding.
    if temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = temperature

    with torch.no_grad():
        output = model.generate(**inputs, **gen_kwargs)

    # FIX: drop the echoed prompt by slicing off the input tokens instead of
    # splitting the decoded string on "<|assistant|>", which is fragile
    # (the tag may be absent or tokenized/decoded differently).
    prompt_length = inputs["input_ids"].shape[1]
    generated_text = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return generated_text.strip()
| |
|
| |
|
| | |
| |
|
@app.post("/generate")
async def generate(request: PromptRequest):
    """Génération de texte libre."""
    try:
        reply = generate_text_from_model(
            system_prompt="Tu es un assistant IA très utile et créatif.",
            user_prompt=request.prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
        )
    except Exception as exc:
        return {"error": str(exc)}
    return {"result": reply}
| |
|
@app.post("/summarize")
async def summarize(request: PromptRequest):
    # Summarization pins a low temperature (0.3) for factual output; the
    # request's own temperature field is deliberately ignored.
    instructions = "Tu es un expert en résumé concis et précis. Ton objectif est de résumer le texte fourni de manière à en conserver l'idée principale."
    task = f"Résume le texte suivant de manière concise et factuelle:\n\n---\n\n{request.prompt}"
    try:
        summary = generate_text_from_model(
            system_prompt=instructions,
            user_prompt=task,
            max_tokens=request.max_tokens,
            temperature=0.3,
        )
    except Exception as exc:
        return {"error": str(exc)}
    return {"result": summary}
| |
|
@app.post("/classify")
async def classify(request: PromptRequest):
    # Classification pins a near-greedy temperature (0.1) and a short output
    # budget (50 tokens); request.max_tokens and request.temperature are
    # deliberately ignored.
    label_instructions = "Tu es un expert en classification. Réponds uniquement avec l'étiquette de classification demandée sans phrases supplémentaires."
    try:
        label = generate_text_from_model(
            system_prompt=label_instructions,
            user_prompt=request.prompt,
            max_tokens=50,
            temperature=0.1,
        )
    except Exception as exc:
        return {"error": str(exc)}
    return {"result": label}