textgen / app.py
1dm's picture
use tyniL
546aaea verified
# Fichier: app.py (VERSION CORRIGÉE FINALE - OPTIMISÉE POUR LA MÉMOIRE DU SPACE)
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os
# --- Configuration du Modèle ---
#model_id = "microsoft/Phi-3-mini-4k-instruct"
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
device = torch.device("cpu")
# --- Stratégie de chargement pour économiser la mémoire (Quantisation) ---
# Si le Space a un GPU/CUDA, la quantisation sera utilisée, réduisant la RAM par 8.
# Si le Space est CPU seulement, cette tentative échouera, et nous utiliserons le fallback float32.
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
# Charger le Tokenizer et le Modèle
try:
tokenizer = AutoTokenizer.from_pretrained(model_id)
# TENTATIVE 1 : Chargement avec Quantisation 4-bit (Méthode recommandée)
print("Tentative de chargement avec quantisation 4-bit...")
# Le chargement en 4-bit nécessite device_map="auto"
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=quantization_config,
device_map="auto",
trust_remote_code=True
)
print(f"Modèle {model_id} chargé et quantifié.")
except Exception as e_quant:
# Si la quantisation échoue (souvent sans GPU), on revient à la version CPU
print(f"Échec de la quantisation : {e_quant}. Tentative de chargement float32 CPU (Attention: peut causer OOM).")
# TENTATIVE 2 : Fallback sur le chargement float32 CPU (Votre code initial, mais avec fix du bug)
try:
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float32,
trust_remote_code=True
).to(device)
print(f"Modèle {model_id} chargé sur CPU (Float32).")
except Exception as e_cpu:
print(f"Échec critique du chargement CPU : {e_cpu}")
# Si même float32 échoue, vous avez BESOIN de plus de RAM pour votre Space.
raise e_cpu
model.eval()
app = FastAPI(
title="NLP Space - Phi-3 Mini API (CPU)",
description="API REST pour génération, résumé et classification de texte, optimisée pour CPU."
)
# Schéma de données pour les requêtes POST
class PromptRequest(BaseModel):
prompt: str
max_tokens: int = 500
temperature: float = 0.7
# Fonction utilitaire pour interagir avec le modèle
def generate_text_from_model(system_prompt: str, user_prompt: str, max_tokens: int, temperature: float):
# Formatage de l'instruction pour Phi-3 Instruct (Chat Template)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
text_to_generate = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
# Trouver le device réel du modèle pour y placer les inputs (nécessaire après le chargement device_map)
# Assurez-vous que le modèle est correctement placé, en le forçant sur CPU si nécessaire.
real_device = model.device if model.device.type != 'meta' else torch.device("cpu")
inputs = tokenizer(text_to_generate, return_tensors="pt").to(real_device)
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=max_tokens,
do_sample=True,
temperature=temperature,
pad_token_id=tokenizer.eos_token_id,
use_cache=False # CORRECTION CRITIQUE 2: Fixe le bug DynamicCache
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
# Nettoyage
response_start_tag = "<|assistant|>"
if response_start_tag in generated_text:
return generated_text.split(response_start_tag, 1)[1].strip()
return generated_text.strip()
# --- Endpoints (Inchangés) ---
@app.post("/generate")
async def generate(request: PromptRequest):
"""Génération de texte libre."""
system_prompt = "Tu es un assistant IA très utile et créatif."
try:
result = generate_text_from_model(
system_prompt=system_prompt,
user_prompt=request.prompt,
max_tokens=request.max_tokens,
temperature=request.temperature
)
return {"result": result}
except Exception as e:
return {"error": str(e)}
@app.post("/summarize")
async def summarize(request: PromptRequest):
# ... (code inchangé) ...
system_prompt = "Tu es un expert en résumé concis et précis. Ton objectif est de résumer le texte fourni de manière à en conserver l'idée principale."
user_prompt = f"Résume le texte suivant de manière concise et factuelle:\n\n---\n\n{request.prompt}"
try:
result = generate_text_from_model(
system_prompt=system_prompt,
user_prompt=user_prompt,
max_tokens=request.max_tokens,
temperature=0.3
)
return {"result": result}
except Exception as e:
return {"error": str(e)}
@app.post("/classify")
async def classify(request: PromptRequest):
# ... (code inchangé) ...
system_prompt = "Tu es un expert en classification. Réponds uniquement avec l'étiquette de classification demandée sans phrases supplémentaires."
user_prompt = request.prompt
try:
result = generate_text_from_model(
system_prompt=system_prompt,
user_prompt=user_prompt,
max_tokens=50,
temperature=0.1
)
return {"result": result}
except Exception as e:
return {"error": str(e)}