Spaces:
Runtime error
Runtime error
| import os | |
| from fastapi import FastAPI | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| from pydantic import BaseModel | |
| app = FastAPI() | |
| # 📌 Définir un dossier cache accessible | |
| os.environ["TRANSFORMERS_CACHE"] = "/tmp" | |
| # 📌 Charger le modèle et le tokenizer avec cache local | |
| MODEL_NAME = "fatmata/psybot" | |
| local_dir = "/tmp/model" | |
| os.makedirs(local_dir, exist_ok=True) | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=local_dir) | |
| model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=local_dir, torch_dtype=torch.float32) | |
| # 📌 Définition du modèle pour recevoir l'entrée utilisateur | |
| class PromptRequest(BaseModel): | |
| prompt: str | |
| def home(): | |
| return {"message": "Bienvenue sur l'API PsyBot !"} | |
| def generate_text(request: PromptRequest): | |
| """ Génère une réponse du chatbot PsyBot """ | |
| user_input = request.prompt | |
| # 📌 Ajouter les balises pour respecter le format du modèle | |
| formatted_prompt = f"<|startoftext|><|user|> {user_input} <|bot|>" | |
| # 📌 Encodage du texte et génération de la réponse | |
| inputs = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device) | |
| with torch.no_grad(): | |
| output = model.generate( | |
| inputs, | |
| max_new_tokens=100, | |
| pad_token_id=tokenizer.eos_token_id, | |
| eos_token_id=tokenizer.eos_token_id, | |
| do_sample=True, # Activation du sampling | |
| temperature=0.7, # Génération plus naturelle | |
| top_k=50, | |
| top_p=0.9, | |
| repetition_penalty=1.2 # Réduction de la répétition | |
| ) | |
| response = tokenizer.decode(output[0], skip_special_tokens=True) | |
| # 🔍 Nettoyage : récupérer uniquement la réponse du bot après <|bot|> | |
| if "<|bot|>" in response: | |
| response = response.split("<|bot|>")[-1].strip() | |
| return {"response": response} | |