SsebaA commited on
Commit
a222a41
Β·
verified Β·
1 Parent(s): d1a3550

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +129 -100
models.py CHANGED
@@ -1,125 +1,154 @@
1
  """
2
  VoiceNote AI - Models
3
- Mistral AI client (HTTP API) and Whisper ASR model
4
  """
5
 
6
  import logging
7
- import requests
8
  import torch
9
- from transformers import pipeline
 
 
10
  from config import Config
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  class MistralClient:
16
- """Mistral AI HTTP API Client"""
17
-
 
 
 
18
  def __init__(self):
19
- """Initialize Mistral client with HTTP API"""
20
- self.api_key = Config.MISTRAL_API_KEY
21
- self.api_url = Config.MISTRAL_API_URL
22
- self.model = Config.MISTRAL_MODEL
23
-
24
- if not self.api_key:
25
- raise ValueError("MISTRAL_API_KEY not found in environment")
26
-
27
- logger.info("Mistral client initialized with HTTP API")
28
-
29
  def generate(self, prompt: str, max_tokens: int = 500, temperature: float = 0.1) -> str:
30
  """
31
- Generate text using Mistral AI API
32
-
33
  Args:
34
- prompt: Input prompt
35
- max_tokens: Maximum tokens to generate
36
- temperature: Sampling temperature
37
-
38
  Returns:
39
- Generated text
40
  """
41
  headers = {
 
42
  "Content-Type": "application/json",
43
- "Authorization": f"Bearer {self.api_key}"
44
  }
45
-
46
  payload = {
47
- "model": self.model,
48
- "messages": [
49
- {"role": "user", "content": prompt}
50
- ],
51
  "max_tokens": max_tokens,
52
- "temperature": temperature
53
  }
54
-
55
- try:
56
- response = requests.post(
57
- self.api_url,
58
- headers=headers,
59
- json=payload,
60
- timeout=30
61
- )
62
- response.raise_for_status()
63
-
64
- result = response.json()
65
- return result['choices'][0]['message']['content']
66
-
67
- except requests.exceptions.RequestException as e:
68
- logger.error(f"Mistral API error: {e}")
69
- raise
70
-
71
-
72
- class ASRModel:
73
- """Automatic Speech Recognition using Whisper"""
74
-
75
- def __init__(self):
76
- """Initialize Whisper ASR model"""
77
- self.model_name = Config.ASR_MODEL_NAME
78
- self.language = Config.ASR_LANGUAGE
79
- self.device = Config.ASR_DEVICE
80
- self.dtype = Config.ASR_DTYPE
81
-
82
- logger.info(f"Loading ASR model: {self.model_name}")
83
-
84
- # Convert dtype string to torch dtype
85
- torch_dtype = torch.float32 if self.dtype == "float32" else torch.float16
86
-
87
- # Load model on CPU with float32 to avoid GPU dtype issues
88
- # Enable long-form transcription with chunk_length_s
89
- self.pipe = pipeline(
90
- "automatic-speech-recognition",
91
- model=self.model_name,
92
- device=self.device,
93
- torch_dtype=torch_dtype,
94
- chunk_length_s=30, # Enable chunking for long audio (>30s)
95
- return_timestamps=False # Don't return timestamps, just text
96
  )
97
-
98
- logger.info(f"ASR model loaded successfully on {self.device}")
99
-
100
- def transcribe(self, audio_path: str) -> str:
101
- """
102
- Transcribe audio file to text
103
-
104
- Args:
105
- audio_path: Path to audio file
106
-
107
- Returns:
108
- Transcribed text
109
- """
110
- logger.info(f"Transcribing audio: {audio_path}")
111
-
112
- try:
113
- # Pass language in generate_kwargs, NOT in model initialization
114
- result = self.pipe(
115
- audio_path,
116
- generate_kwargs={"language": self.language}
117
- )
118
-
119
- text = result["text"].strip()
120
- logger.info(f"Transcription successful: {len(text)} characters")
121
- return text
122
-
123
- except Exception as e:
124
- logger.error(f"Transcription error: {e}")
125
- raise
 
1
  """
2
  VoiceNote AI - Models
3
+ ASR (Whisper) and LLM (Mistral) clients + DeepL translation layer
4
  """
5
 
6
  import logging
 
7
  import torch
8
+ import deepl
9
+ import requests
10
+ from transformers import pipeline as hf_pipeline
11
  from config import Config
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
 
16
+ # ══════════════════════════════════════════════════════════
17
+ # ASR β€” KBLab fine-tuned Whisper for Swedish
18
+ # ══════════════════════════════════════════════════════════
19
+
20
+ class WhisperASR:
21
+ """
22
+ Swedish ASR using KBLab's fine-tuned Whisper model.
23
+
24
+ KBLab/whisper-large-v3-swedish is trained on Swedish speech corpora,
25
+ significantly outperforming openai/whisper-small on Swedish medical text.
26
+ Reference: Vesterbacka et al. (2025), 'Swedish Whispers'.
27
+
28
+ The model runs locally on ZeroGPU β€” no audio leaves the server.
29
+ Audio is split into 30-second chunks to avoid GPU memory issues.
30
+ """
31
+
32
+ def __init__(self):
33
+ self._pipe = None
34
+
35
+ def _load(self):
36
+ """Lazy-load the model on first call (ZeroGPU requires GPU context)."""
37
+ if self._pipe is None:
38
+ logger.info(f"Loading ASR model: {Config.ASR_MODEL_NAME}")
39
+ self._pipe = hf_pipeline(
40
+ task="automatic-speech-recognition",
41
+ model=Config.ASR_MODEL_NAME,
42
+ torch_dtype=torch.float16,
43
+ device="cuda",
44
+ )
45
+ return self._pipe
46
+
47
+ def transcribe(self, audio_path: str) -> str:
48
+ """
49
+ Transcribe a Swedish audio file to text.
50
+
51
+ Args:
52
+ audio_path: Path to audio file (wav/mp3/m4a)
53
+
54
+ Returns:
55
+ Transcribed Swedish text
56
+ """
57
+ pipe = self._load()
58
+ result = pipe(
59
+ audio_path,
60
+ generate_kwargs={"language": Config.ASR_LANGUAGE, "task": "transcribe"},
61
+ chunk_length_s=Config.ASR_CHUNK_LENGTH_S,
62
+ stride_length_s=Config.ASR_STRIDE_LENGTH_S,
63
+ return_timestamps=False,
64
+ )
65
+ return result["text"].strip()
66
+
67
+
68
+ # ══════════════════════════════════════════════════════════
69
+ # TRANSLATION β€” DeepL (Frankfurt, within EU)
70
+ # ══════════════════════════════════════════════════════════
71
+
72
+ class DeepLTranslator:
73
+ """
74
+ Translates anonymized Swedish text to English via DeepL API.
75
+
76
+ Why: Mistral AI has limited Swedish NLP capability. Zero-shot and
77
+ Chain-of-Thought prompting in Swedish often produces empty or
78
+ incorrect VIPS output. Translating to English first resolves this
79
+ while keeping all data within EU jurisdiction (DeepL Frankfurt).
80
+
81
+ Data flow remains GDPR-compliant:
82
+ Whisper [local GPU] β†’ GDPR filter [local] β†’ DeepL [Frankfurt πŸ‡©πŸ‡ͺ]
83
+ β†’ Mistral [Paris πŸ‡«πŸ‡·] β†’ result
84
+ """
85
+
86
+ def __init__(self):
87
+ if not Config.DEEPL_API_KEY:
88
+ raise EnvironmentError("DEEPL_API_KEY saknas i HuggingFace Secrets.")
89
+ self._translator = deepl.Translator(Config.DEEPL_API_KEY)
90
+
91
+ def translate(self, swedish_text: str) -> str:
92
+ """
93
+ Translate Swedish text to English.
94
+
95
+ Args:
96
+ swedish_text: Anonymized Swedish patient text
97
+
98
+ Returns:
99
+ English translation
100
+ """
101
+ result = self._translator.translate_text(
102
+ swedish_text,
103
+ source_lang="SV",
104
+ target_lang="EN-US",
105
+ )
106
+ logger.info("DeepL translation completed (SV β†’ EN)")
107
+ return result.text
108
+
109
+
110
+ # ══════════════════════════════════════════════════════════
111
+ # LLM β€” Mistral AI (Paris, within EU)
112
+ # ══════════════════════════════════════════════════════════
113
+
114
  class MistralClient:
115
+ """
116
+ Mistral AI client for VIPS classification.
117
+ Mistral is based in Paris (France) β€” data stays within EU/GDPR.
118
+ """
119
+
120
  def __init__(self):
121
+ if not Config.MISTRAL_API_KEY:
122
+ raise EnvironmentError("MISTRAL_API_KEY saknas i HuggingFace Secrets.")
123
+ self._api_key = Config.MISTRAL_API_KEY
124
+
 
 
 
 
 
 
125
  def generate(self, prompt: str, max_tokens: int = 500, temperature: float = 0.1) -> str:
126
  """
127
+ Send a prompt to Mistral and return the generated text.
128
+
129
  Args:
130
+ prompt: Full prompt string (system + user content)
131
+ max_tokens: Maximum tokens in response
132
+ temperature: Sampling temperature (low = more deterministic)
133
+
134
  Returns:
135
+ Model response text
136
  """
137
  headers = {
138
+ "Authorization": f"Bearer {self._api_key}",
139
  "Content-Type": "application/json",
 
140
  }
 
141
  payload = {
142
+ "model": Config.MISTRAL_MODEL,
143
+ "messages": [{"role": "user", "content": prompt}],
 
 
144
  "max_tokens": max_tokens,
145
+ "temperature": temperature,
146
  }
147
+ response = requests.post(
148
+ "https://api.mistral.ai/v1/chat/completions",
149
+ headers=headers,
150
+ json=payload,
151
+ timeout=60,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  )
153
+ response.raise_for_status()
154
+ return response.json()["choices"][0]["message"]["content"].strip()