x / models.py
SsebaA's picture
Update models.py
18f455f verified
"""
VoiceNote AI - Models
ASR: KBLab Whisper (local GPU)
Translation: DeepL (Frankfurt 🇩🇪)
LLM: Scaleway Generative API via OpenAI SDK (Paris 🇫🇷)
"""
import logging
import torch
import deepl
from openai import OpenAI
from transformers import pipeline as hf_pipeline
from config import Config
logger = logging.getLogger(__name__)
class WhisperASR:
    """Speech-to-text wrapper around a Hugging Face ASR pipeline.

    The pipeline is built lazily on the first call to :meth:`transcribe`
    and then reused, so constructing this object does not load the model.
    """

    def __init__(self):
        # Filled in by _load() the first time a transcription is requested.
        self._pipe = None

    @staticmethod
    def _select_device():
        """Return the ``(device, torch_dtype)`` pair used to load the model.

        CUDA with float16 is used when torch reports a usable GPU; otherwise
        the model is loaded on CPU in float32 instead of failing outright.
        """
        if torch.cuda.is_available():
            return "cuda", torch.float16
        return "cpu", torch.float32

    def _load(self):
        """Build the pipeline on first use and return the cached instance."""
        if self._pipe is None:
            device, dtype = self._select_device()
            logger.info("Loading ASR model: %s", Config.ASR_MODEL_NAME)
            if device != "cuda":
                logger.warning(
                    "CUDA is not available; loading ASR model on CPU (slow)."
                )
            self._pipe = hf_pipeline(
                task="automatic-speech-recognition",
                model=Config.ASR_MODEL_NAME,
                torch_dtype=dtype,
                device=device,
            )
        return self._pipe

    def transcribe(self, audio_path: str) -> str:
        """Transcribe the audio file at *audio_path*.

        Language, chunk length and stride are taken from ``Config``.
        Timestamps are not requested.

        Returns:
            The recognized text with leading/trailing whitespace removed.
        """
        pipe = self._load()
        result = pipe(
            audio_path,
            generate_kwargs={"language": Config.ASR_LANGUAGE, "task": "transcribe"},
            chunk_length_s=Config.ASR_CHUNK_LENGTH_S,
            stride_length_s=Config.ASR_STRIDE_LENGTH_S,
            return_timestamps=False,
        )
        return result["text"].strip()
class DeepLTranslator:
    """Swedish-to-US-English translator backed by the DeepL API."""

    def __init__(self):
        """Create the DeepL client from ``Config.DEEPL_API_KEY``.

        Raises:
            EnvironmentError: if the API key is not configured.
        """
        if not Config.DEEPL_API_KEY:
            # Message is Swedish for "DEEPL_API_KEY is missing."
            raise EnvironmentError("DEEPL_API_KEY saknas.")
        self._translator = deepl.Translator(Config.DEEPL_API_KEY)

    def translate(self, swedish_text: str) -> str:
        """Translate *swedish_text* from Swedish to US English.

        Empty or whitespace-only input yields ``""`` without contacting the
        API: there is nothing to translate, and the DeepL client may reject
        empty text.

        Returns:
            The translated text, or ``""`` for blank input.
        """
        if not swedish_text or not swedish_text.strip():
            return ""
        result = self._translator.translate_text(
            swedish_text, source_lang="SV", target_lang="EN-US"
        )
        return result.text
class MistralClient:
    """
    Scaleway Generative API — OpenAI SDK compatible.
    Paris 🇫🇷 (OPCORE datacenter), 100% EU GDPR.
    Free tier: 1M tokens. First token latency <1s.
    """

    def __init__(self):
        """Create the OpenAI-compatible client pointed at Scaleway.

        Raises:
            EnvironmentError: if ``Config.SCALEWAY_API_KEY`` is not set.
        """
        if not Config.SCALEWAY_API_KEY:
            # Message is Swedish for "SCALEWAY_API_KEY is missing in
            # HuggingFace Secrets."
            raise EnvironmentError("SCALEWAY_API_KEY saknas i HuggingFace Secrets.")
        self._client = OpenAI(
            base_url="https://api.scaleway.ai/v1",
            api_key=Config.SCALEWAY_API_KEY,
        )

    @staticmethod
    def _extract_text(response) -> str:
        """Return the first choice's message text, stripped.

        The message content may be ``None``; that case is reported as an
        empty string instead of raising ``AttributeError`` on ``.strip()``.
        """
        content = response.choices[0].message.content
        return (content or "").strip()

    def generate(self, prompt: str, max_tokens: int = 500, temperature: float = 0.1) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Args:
            prompt: The user message content.
            max_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature; values below 0.15 (including
                the default) are raised to 0.15 before the request is sent.

        Returns:
            The stripped text of the first completion choice, or ``""`` when
            the model returned no text content.
        """
        response = self._client.chat.completions.create(
            model=Config.SCALEWAY_MODEL,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=max(temperature, 0.15),  # Scaleway recommends >=0.15
            timeout=25,
        )
        return self._extract_text(response)