Nelya / app.py
Clemylia's picture
Update app.py
69e5a19 verified
Raw
History Blame Contribute Delete
2.24 kB
import asyncio
import os
import sqlite3
import threading

import torch
import uvicorn
from fastapi_poe import PoeBot, make_app
from fastapi_poe.types import QueryRequest, PartialResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- MODEL CONFIGURATION ---
# Hugging Face Hub identifier of the causal language model served by the bot.
MODEL_ID = "Finisha-F-scratch/Nelya-neko"

print("--> Chargement du tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print("--> Chargement du modèle (cette étape peut prendre du temps)...")
# float16 when a CUDA device is available, float32 otherwise.
_model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=_model_dtype,
    # Let transformers decide where to place the weights.
    device_map="auto",
)
# --- DATABASE ---
# SQLite database file, stored in the current user's home directory.
DB_NAME = os.path.join(
    os.path.expanduser("~"),
    "charlotte_api.db",
)
def init_db(db_path=None):
    """Create the SQLite schema used by the bot if it does not exist yet.

    Creates the ``api_keys`` table (``key`` TEXT primary key,
    ``requests_count`` INTEGER). Calling it repeatedly is harmless thanks to
    ``CREATE TABLE IF NOT EXISTS``.

    Args:
        db_path: Path of the SQLite database file. When omitted, the
            module-level ``DB_NAME`` is used (the original behavior).
    """
    conn = sqlite3.connect(DB_NAME if db_path is None else db_path)
    try:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS api_keys "
            "(key TEXT PRIMARY KEY, requests_count INTEGER)"
        )
        conn.commit()
    finally:
        # Always release the connection, even if execute()/commit() raises.
        # sqlite3's own context manager only commits/rolls back; it does not close.
        conn.close()
init_db()  # Create the database schema at import time, before the bot is built.
# --- POE BOT CLASS ---
class CharlottePoeBot(PoeBot):
    """Poe bot that answers with text generated by the locally loaded model.

    Only the last message of the conversation is used as the prompt; earlier
    turns are ignored.
    """

    # Serializes model.generate() across executor threads so that concurrent
    # requests never run several generations on the shared model at once.
    _generation_lock = threading.Lock()

    def _generate_reply(self, prompt):
        """Run the model synchronously on *prompt* and return only the new text.

        Blocking; get_response() runs it in a worker thread so the event loop
        stays responsive while the model generates.
        """
        inputs = tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=128
        ).to(model.device)
        # Number of prompt tokens actually fed to the model (after truncation).
        prompt_length = inputs["input_ids"].shape[-1]
        # torch.no_grad() is thread-local, so it must be entered in the thread
        # that calls generate().
        with self._generation_lock, torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                do_sample=True,
                temperature=0.7,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens. Slicing the decoded text by
        # len(prompt) is wrong when the prompt was truncated to 128 tokens or
        # when decoding does not reproduce the prompt's characters exactly.
        new_tokens = outputs[0][prompt_length:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    async def get_response(self, request: QueryRequest):
        """Yield a single PartialResponse containing the model's reply."""
        # The last message sent by the user is the prompt.
        last_message = request.query[-1].content
        # Offload the blocking generation so other requests are not stalled.
        loop = asyncio.get_running_loop()
        response_text = await loop.run_in_executor(
            None, self._generate_reply, last_message
        )
        yield PartialResponse(text=response_text)
# --- APPLICATION CREATION ---
# A single bot instance wrapped by fastapi_poe into the web app uvicorn serves.
bot = CharlottePoeBot()
app = make_app(bot)

# --- SERVER STARTUP ---
if __name__ == "__main__":
    # The PORT environment variable overrides the default port 7860.
    port_to_use = int(os.getenv("PORT", 7860))
    print(f"--> Démarrage d'Uvicorn sur le port {port_to_use}...")
    # Listen on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=port_to_use)