# salitako2.0 / main.py — SalitaKo speech-coaching API server (FastAPI).
import re
import math
import socket
import sqlite3
import datetime
import numpy as np
from scipy.signal import butter, sosfilt
from scipy.io import wavfile
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import asyncio
import tempfile
import os
import uuid
from contextlib import asynccontextmanager
import httpx
from faster_whisper import WhisperModel
from zeroconf import ServiceInfo
from zeroconf.asyncio import AsyncZeroconf
# mDNS Service Configuration (LAN discovery via zeroconf)
SERVICE_TYPE = "_salitako._tcp.local."  # zeroconf service type
SERVICE_NAME = "SalitaKo Server._salitako._tcp.local."  # advertised instance name
SERVICE_PORT = 8000  # port announced over mDNS (matches the uvicorn port below)
# Cloud deployment detection (Hugging Face Spaces, Railway, etc.)
IS_CLOUD = os.environ.get("SPACE_ID") is not None or os.environ.get("RAILWAY_ENVIRONMENT") is not None
# Service Mode Configuration (Split Architecture)
SERVICE_MODE = os.environ.get("SERVICE_MODE", "audio").lower() # 'audio' or 'nlp'
# Base URL of the remote NLP microservice; empty string means "not configured".
NLP_API_URL = os.environ.get("NLP_API_URL", "").rstrip("/")
# ──────────────────────────────────────────────────────────────
# Filipino / Taglish vocabulary hint for Whisper initial_prompt.
# Priming the decoder with real Filipino words dramatically
# reduces mis-hearings like "amo" → "ano".
# ──────────────────────────────────────────────────────────────
FILIPINO_VOCAB_PROMPT = (
    "Ang, ang, mga, na, sa, ng, ko, mo, niya, namin, nila, "
    "ano, ito, iyon, siya, kami, tayo, sila, "
    "hindi, oo, wala, meron, paano, bakit, "
    "kasi, diba, yung, naman, pala, talaga, "
    "po, ho, kuya, ate, "
    "maganda, mabuti, masaya, malaki, maliit, "
    "kumain, uminom, pumunta, naglaro, natulog, "
    "paaralan, bahay, trabaho, kaibigan, pamilya, "
    "salamat, magandang, umaga, hapon, gabi, "
    # Common English loanwords/test phrases
    "hello, hi, mic, test, testing, okay, yes, no"
)
# Known Whisper misrecognitions for Filipino — extend as needed.
# Keys are matched as whole words, case-insensitively.
WHISPER_CORRECTIONS: dict[str, str] = {
    "amo": "ano",
    "cayo": "kayo",
    "yong": "yung",
    "cami": "kami",
    "cum": "kum",
    "naman naman": "naman",
    # English loanword corrections
    "helo": "hello",
    "mike": "mic",
    "test": "test",  # to ensure it's not accidentally stripped
}


def _carry_leading_case(original: str, replacement: str) -> str:
    """Copy the original token's leading capitalization onto the replacement."""
    if original[:1].isupper() and replacement:
        return replacement[0].upper() + replacement[1:]
    return replacement


def post_process_transcript(text: str) -> str:
    """Fix known Whisper misrecognitions for Filipino.

    Each correction is applied as a case-insensitive whole-word regex, so
    tokens with adjacent punctuation (e.g. "Amo?") are still corrected, and
    leading capitalization is preserved ("Amo" -> "Ano"). Multi-word keys
    such as "naman naman" are handled by the same mechanism.
    """
    for wrong, right in WHISPER_CORRECTIONS.items():
        pattern = rf"\b{re.escape(wrong)}\b"
        text = re.sub(
            pattern,
            # Bind `right` per iteration to avoid the late-binding closure trap.
            lambda m, repl=right: _carry_leading_case(m.group(0), repl),
            text,
            flags=re.IGNORECASE,
        )
    return text
def preprocess_audio(file_path: str) -> str:
    """Apply high-pass filter + normalization to reduce background noise.

    Reads a WAV file, scales it to float32 in [-1, 1] based on its actual
    sample dtype, downmixes stereo to mono, removes low-frequency rumble
    with an 80 Hz high-pass filter, peak-normalizes, and writes the result
    next to the input as "<name>_clean.wav".

    Returns the path of the processed file, or the original path if any
    step fails (best-effort: transcription proceeds on raw audio).
    """
    try:
        sr, audio = wavfile.read(file_path)
        # Scale by the real dtype: wavfile returns int16/int32/uint8 or
        # float depending on the file, not always int16.
        if np.issubdtype(audio.dtype, np.integer):
            scale = float(np.iinfo(audio.dtype).max) + 1.0
            audio = audio.astype(np.float32) / scale
        else:
            audio = audio.astype(np.float32)
        # Downmix multi-channel audio so sosfilt runs along time; its
        # default axis=-1 would otherwise filter across channels.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        # High-pass at 80 Hz — removes low rumble / AC hum
        sos = butter(5, 80, btype="highpass", fs=sr, output="sos")
        audio = sosfilt(sos, audio)
        # Peak-normalize to 0.95
        peak = np.max(np.abs(audio))
        if peak > 0:
            audio = audio / peak * 0.95
        processed_path = file_path.replace(".wav", "_clean.wav")
        wavfile.write(processed_path, sr, (audio * 32767).astype(np.int16))
        return processed_path
    except Exception as e:
        print(f"⚠️ Audio preprocessing failed (using raw): {e}")
        return file_path
def get_local_ip():
    """Get the local IP address of this machine.

    Connecting a UDP socket sends no packets; it only asks the OS to pick
    the outbound interface, whose address we then read. Falls back to
    127.0.0.1 when no route is available.
    """
    try:
        # Context manager guarantees the socket is closed even if
        # connect()/getsockname() raises (the original leaked it).
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))
            return s.getsockname()[0]
    except Exception:
        return "127.0.0.1"
# Global async zeroconf instance — populated during startup in lifespan();
# stays None on cloud deployments or when registration fails.
async_zeroconf = None
service_info = None
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
# Global model instances — loaded in lifespan() according to SERVICE_MODE.
model = None # Whisper (faster-whisper WhisperModel)
roberta_model = None  # Tagalog RoBERTa masked-LM for fluency scoring
roberta_tokenizer = None  # tokenizer paired with roberta_model
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage mDNS service registration and Model loading on startup/shutdown.

    Startup sequence:
      1. Load Whisper (audio mode only): GPU 'small'/float16 when CUDA is
         available, else CPU 'medium'/int8, with 'base'/int8 as last resort.
      2. Load the Tagalog RoBERTa masked-LM (nlp mode only).
      3. Register the mDNS service for LAN discovery (skipped on cloud).
    Shutdown unregisters and closes the mDNS service.
    """
    global async_zeroconf, service_info, model, roberta_model, roberta_tokenizer
    # 1. Load Whisper
    if SERVICE_MODE == "audio":
        print("⏳ Loading Whisper model...")
        try:
            print(f"πŸ”§ CUDA Available: {torch.cuda.is_available()}")
            if torch.cuda.is_available():
                print(f"πŸ”§ GPU Device: {torch.cuda.get_device_name(0)}")
                model = WhisperModel(
                    "small", # 3x more accurate than 'base'
                    device="cuda",
                    compute_type="float16"
                )
            else:
                # CPU / free HF Space — medium+int8 fits in ~1.5 GB RAM
                print("πŸ”§ Using CPU mode (medium + int8)")
                model = WhisperModel("medium", device="cpu", compute_type="int8")
            # NOTE(review): this message claims 'medium' even on the CUDA
            # path, which actually loads 'small'.
            print("βœ… Whisper 'medium' model loaded successfully")
        except Exception as e:
            print(f"❌ Failed to load Whisper model: {e}")
            print("⚠️ Falling back to base/int8...")
            model = WhisperModel("base", device="cpu", compute_type="int8")
    else:
        print("⏭️ Audio Service Mode not active, skipping Whisper.")
    # 2. Load RoBERTa (Tagalog)
    if SERVICE_MODE == "nlp":
        print("⏳ Loading RoBERTa (Tagalog) model...")
        try:
            # Use jcblaise/roberta-tagalog-base for fluency/coherence
            model_name = "jcblaise/roberta-tagalog-base"
            roberta_tokenizer = AutoTokenizer.from_pretrained(model_name)
            roberta_model = AutoModelForMaskedLM.from_pretrained(model_name)
            if torch.cuda.is_available():
                roberta_model.to("cuda")
            roberta_model.eval() # Set to evaluation mode
            print("βœ… RoBERTa model loaded successfully")
        except Exception as e:
            print(f"❌ Failed to load RoBERTa model: {e}")
            # Leave both as None so scoring falls back to the heuristic.
            roberta_model = None
            roberta_tokenizer = None
    else:
        print("⏭️ NLP Service Mode not active, skipping RoBERTa.")
    # Startup: Register mDNS service (skip on cloud deployments)
    if IS_CLOUD:
        print("☁️ Cloud deployment detected - skipping mDNS registration")
    else:
        local_ip = get_local_ip()
        print(f"🌐 Local IP: {local_ip}")
        try:
            async_zeroconf = AsyncZeroconf()
            service_info = ServiceInfo(
                SERVICE_TYPE,
                SERVICE_NAME,
                addresses=[socket.inet_aton(local_ip)],
                port=SERVICE_PORT,
                properties={
                    "version": "0.2.0",
                    "api": "/docs",
                    "name": "SalitaKo Speech Coach"
                },
                server=f"salitako.local.",
            )
            await async_zeroconf.async_register_service(service_info)
            print(f"πŸ“‘ mDNS service registered: {SERVICE_NAME} at {local_ip}:{SERVICE_PORT}")
        except Exception as e:
            # mDNS is a convenience; the API still works without discovery.
            print(f"⚠️ mDNS registration failed (non-fatal): {e}")
            async_zeroconf = None
    # Application serves requests while suspended at this yield.
    yield
    # Shutdown: Unregister mDNS service
    if async_zeroconf and service_info:
        print("πŸ“‘ Unregistering mDNS service...")
        try:
            await async_zeroconf.async_unregister_service(service_info)
            await async_zeroconf.async_close()
        except Exception as e:
            print(f"⚠️ mDNS unregister failed: {e}")
app = FastAPI(title="SalitaKo API", version="0.2.0", lifespan=lifespan)


@app.get("/")
async def read_root():
    """Landing endpoint: advertises docs/health URLs using the LAN IP."""
    local_ip = get_local_ip()
    return {
        "message": "Welcome to SalitaKo API",
        "docs_url": f"http://{local_ip}:8000/docs",
        "health_check": f"http://{local_ip}:8000/health",
        "local_ip": local_ip
    }
# CORS: let the browser frontend call this API from other origins.
# NOTE(review): browsers reject the "*" wildcard origin when
# allow_credentials=True (per the CORS spec) — restrict origins before
# shipping to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "https://*.hf.space", # Hugging Face Spaces
        "*" # Allow all for development (restrict in production)
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class SessionResult(BaseModel):
    """Payload accepted by /log-session for thesis research analysis."""
    student_name: str
    wpm: float  # words per minute
    fluency_score: float
    filler_count: int
    duration_seconds: int
@app.post("/log-session")
async def log_session_result(data: SessionResult):
"""Log session results to a local SQLite database for research analysis."""
try:
# Connect to a simple file-based DB
conn = sqlite3.connect('thesis_data.db')
cursor = conn.cursor()
# Create table if it doesn't exist
cursor.execute('''
CREATE TABLE IF NOT EXISTS results (
id INTEGER PRIMARY KEY AUTOINCREMENT,
student_name TEXT,
wpm REAL,
fluency_score REAL,
filler_count INTEGER,
duration INTEGER,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
''')
# Insert the data
cursor.execute('''
INSERT INTO results (student_name, wpm, fluency_score, filler_count, duration)
VALUES (?, ?, ?, ?, ?)
''', (data.student_name, data.wpm, data.fluency_score, data.filler_count, data.duration_seconds))
conn.commit()
conn.close()
print(f"πŸ“ Logged session for {data.student_name}")
return {"status": "logged"}
except Exception as e:
print(f"❌ Failed to log session: {e}")
return {"status": "error", "message": str(e)}
class AppConfig(BaseModel):
    """Static configuration served by /config to the frontend UI."""
    update_interval_seconds: int
    supported_languages: list[str]
    semantic_score_min: int
    semantic_score_max: int


class SessionCreateResponse(BaseModel):
    """Response from POST /sessions."""
    session_id: str


class FillerInfo(BaseModel):
    """Filler-word statistics produced by detect_fillers()."""
    count: int
    fillers_detected: list[str]


class PaceInfo(BaseModel):
    """Speaking-pace metrics produced by calculate_pace()."""
    wpm: float
    status: str # Slow, Normal, Fast


class ProsodyInfo(BaseModel):
    """Prosody metrics produced by analyze_prosody()."""
    volume_db: float | None  # currently 0.0 (no segments) or None — not measured
    silence_ratio: float | None


class Feedback(BaseModel):
    """Filipino-language coaching messages from generate_feedback()."""
    general: str
    pacing: str
    fillers: str
    coherence: str


class ChunkAnalysisResponse(BaseModel):
    """Full analysis payload returned by the audio-chunk endpoint."""
    transcript: str
    wpm: float | None
    filler_count: int | None
    # Detailed analysis
    fillers: FillerInfo | None
    pacing: PaceInfo | None
    prosody: ProsodyInfo | None
    coherence_score: float | None
    feedback: Feedback | None
    message: str


# Lightweight response for real-time transcription (no analysis)
class QuickTranscriptResponse(BaseModel):
    """Minimal transcript payload for the fast /transcribe endpoint."""
    transcript: str
    has_speech: bool # For auto-stop detection
    message: str
@app.get("/health")
async def health_check():
return {"status": "ok"}
@app.get("/config", response_model=AppConfig)
async def get_config():
"""Return static configuration for the frontend UI."""
return AppConfig(
update_interval_seconds=3,
supported_languages=["en", "fil"],
semantic_score_min=0,
semantic_score_max=100,
)
@app.post("/sessions", response_model=SessionCreateResponse)
async def create_session():
"""Create a new speaking session and return its ID.
For now, the session is not persisted; this is a placeholder
to be backed by a database later.
"""
session_id = str(uuid.uuid4())
return SessionCreateResponse(session_id=session_id)
def detect_fillers(text: str) -> FillerInfo:
    """Detect and count common Filipino filler words.

    Tokenizes case-insensitively on word boundaries, so fillers attached
    to punctuation ("ano,") are still counted; repeated fillers are
    counted every time they occur.
    """
    # frozenset gives O(1) membership tests instead of scanning a list
    # once per token.
    keywords = frozenset({
        "ano", "ah", "uh", "uhm", "parang", "kasi", "ganun",
        "e", "eh", "diba", "yung", "bale", "so", "like",
    })
    detected = [w for w in re.findall(r"\b\w+\b", text.lower()) if w in keywords]
    return FillerInfo(count=len(detected), fillers_detected=detected)
def calculate_pace(transcript: str, duration_seconds: float) -> PaceInfo:
    """Compute words-per-minute from the transcript and classify the speed."""
    if duration_seconds <= 0:
        # Guard against division by zero; report a neutral pace.
        return PaceInfo(wpm=0.0, status="Normal")
    word_count = len(transcript.split())
    wpm = word_count * 60.0 / duration_seconds
    if wpm < 100:
        label = "Slow"
    elif wpm > 160:
        label = "Fast"
    else:
        label = "Normal"
    # Round to two decimals via formatting, matching the rest of the file.
    return PaceInfo(wpm=float(f"{wpm:.2f}"), status=label)
def analyze_prosody(segments: list, duration_seconds: float) -> ProsodyInfo:
    """Estimate the silence ratio from Whisper segment timings."""
    if not segments:
        # No voiced segments at all: the whole chunk counts as silence.
        return ProsodyInfo(volume_db=0.0, silence_ratio=1.0)
    voiced = sum(seg.end - seg.start for seg in segments)
    idle = max(0.0, duration_seconds - voiced)
    ratio = idle / duration_seconds if duration_seconds > 0 else 0.0
    return ProsodyInfo(volume_db=None, silence_ratio=float(f"{ratio:.2f}"))
def calculate_fluency_local(text: str) -> float:
    """
    Calculate a fluency score (1-10) using RoBERTa perplexity (PPL).
    Lower PPL = More natural/fluent.

    Falls back to check_coherence_heuristic() when the model is not loaded
    or inference fails; inputs shorter than two words score the minimum 1.0.
    """
    global roberta_model, roberta_tokenizer
    if not roberta_model or not roberta_tokenizer:
        # Fallback to simple heuristic if model not loaded
        return check_coherence_heuristic(text)
    if not text.strip() or len(text.split()) < 2:
        return 1.0 # Too short to score meaningfully
    try:
        inputs = roberta_tokenizer(text, return_tensors="pt")
        # Keep inputs on the same device as the model.
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}
        with torch.no_grad():
            # Masked-LM loss with the input ids as labels; exp(loss) acts
            # as a (pseudo-)perplexity of the text under the model.
            outputs = roberta_model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss
            ppl = torch.exp(loss).item()
        # Normalize PPL to Score (1-10)
        # Typical coherent text has PPL 5-50.
        # >100 is likely incoherent.
        # Score = 10 - (log(PPL) * factor)
        # PPL 10 -> Score ~8
        # PPL 100 -> Score ~3
        score = max(1.0, min(10.0, 11.0 - math.log(ppl)))
        return float(f"{score:.2f}")
    except Exception as e:
        print(f"⚠️ RoBERTa analysis failed: {e}")
        return check_coherence_heuristic(text)
async def get_fluency_score(text: str) -> float:
    """Gets the fluency score, either locally (NLP mode) or remotely (Audio mode).

    Resolution order:
      1. NLP mode -> score with the local RoBERTa model.
      2. Audio mode with NLP_API_URL configured -> POST to the NLP microservice.
      3. Otherwise, or on any remote failure -> heuristic fallback.
    """
    if SERVICE_MODE == "nlp":
        return calculate_fluency_local(text)
    if NLP_API_URL:
        # Call the NLP Microservice
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                res = await client.post(f"{NLP_API_URL}/fluency", json={"text": text})
                if res.status_code == 200:
                    # Default 5.0 if the remote omits the field.
                    return res.json().get("coherence_score", 5.0)
                else:
                    print(f"⚠️ External NLP API returned {res.status_code}, falling back to heuristic.")
        except Exception as e:
            print(f"⚠️ Failed to connect to NLP API at {NLP_API_URL}: {e}")
    # Fallback heuristic if local model missing and no external API configured/available
    return check_coherence_heuristic(text)
def check_coherence_heuristic(text: str) -> float:
    """Fallback coherence estimate based on length and word repetition.

    Starts from a neutral 5.0 and subtracts 2.0 for very short fragments
    and another 2.0 for heavy repetition; never drops below 1.0.
    """
    tokens = text.lower().split()
    penalty = 0.0
    # Very short fragments read as incoherent.
    if len(tokens) < 3:
        penalty += 2.0
    # Heavy repetition: fewer than half the words are unique.
    if len(tokens) > 4 and len(set(tokens)) / len(tokens) < 0.5:
        penalty += 2.0
    return max(1.0, 5.0 - penalty)
def generate_feedback(pace: PaceInfo, fillers: FillerInfo, prosody: ProsodyInfo, coherence_score: float) -> Feedback:
    """Compose Filipino-language coaching messages from the computed metrics."""
    # Pacing message keyed by the classified status.
    pacing_by_status = {
        "Fast": "Medyo mabilis ang iyong pagsasalita. Subukang bagalan ng kaunti para mas maintindihan.",
        "Slow": "Medyo mabagal. Subukang bilisan nang kaunti para mas tuloy-tuloy ang daloy.",
    }
    pacing_msg = pacing_by_status.get(pace.status, "Ayos ang iyong bilis! Panatilihin ito.")
    # Filler message: count > 2 guarantees fillers_detected is non-empty.
    if fillers.count > 2:
        filler_msg = f"Napansin ko ang paggamit ng '{fillers.fillers_detected[0]}'. Subukang mag-pause sandali sa halip na gumamit ng filler words."
    else:
        filler_msg = "Mahusay! Malinis ang iyong pagsasalita mula sa mga filler words."
    # General + coherence messages share the same threshold.
    low_coherence = coherence_score < 3.0
    coherence_msg = (
        "Medyo putol-putol ang ideya. Subukang buuin ang pangungusap."
        if low_coherence
        else "Malinaw ang daloy ng iyong ideya."
    )
    general_msg = "Kaya mo yan! Practice pa tayo." if low_coherence else "Maganda ang iyong performance!"
    return Feedback(
        general=general_msg,
        pacing=pacing_msg,
        fillers=filler_msg,
        coherence=coherence_msg
    )
from fastapi import Form, UploadFile, File
class FluencyRequest(BaseModel):
    """Request body for /fluency: the raw transcript text to score."""
    text: str


class FluencyResponse(BaseModel):
    """Response from /fluency: coherence score clamped to 1-10."""
    coherence_score: float
@app.post("/fluency", response_model=FluencyResponse)
async def analyze_fluency(req: FluencyRequest):
"""External endpoint for Audio service to request fluency scoring. (NLP Mode Only)"""
score = calculate_fluency_local(req.text)
return FluencyResponse(coherence_score=score)
@app.post("/sessions/{session_id}/transcribe", response_model=QuickTranscriptResponse)
async def quick_transcribe(
session_id: str,
file: UploadFile = File(...),
prompt: str = Form("") # Optional previous context
):
"""Fast transcription endpoint with context prompt."""
audio_bytes = await file.read()
def _transcribe() -> tuple[str, bool]:
tmp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
try:
tmp_file.write(audio_bytes)
tmp_file.flush()
tmp_file.close()
# Preprocess: high-pass filter + normalize
audio_path = preprocess_audio(tmp_file.name)
# Combine vocab hint + previous context for better accuracy
if prompt:
initial_prompt_text = f"{FILIPINO_VOCAB_PROMPT}. {prompt}"
else:
initial_prompt_text = FILIPINO_VOCAB_PROMPT
segments, info = model.transcribe(
audio_path,
language="tl", # Force Tagalog/Taglish to prevent Spanish detection
task="transcribe",
beam_size=3,
word_timestamps=True, # Better alignment, fewer hallucinations
vad_filter=True, # Re-enable VAD to help with silence (looping)
vad_parameters=dict(min_silence_duration_ms=1000),
initial_prompt=initial_prompt_text,
condition_on_previous_text=False,
# Filters to reduce hallucinations/looping:
temperature=[0.0, 0.2, 0.4],
compression_ratio_threshold=2.4, # Filter loops
log_prob_threshold=-1.0, # Filter uncertain nonsense
no_speech_threshold=0.6, # Filter silence
)
texts = [seg.text.strip() for seg in segments if seg.text]
transcript = " ".join(texts).strip()
# Post-process: fix known misrecognitions
transcript = post_process_transcript(transcript)
# Consider any non-trivial transcript as speech
has_speech = len(transcript) > 2
return transcript, has_speech
finally:
try:
os.remove(tmp_file.name)
except OSError:
pass
try:
transcript, has_speech = await asyncio.to_thread(_transcribe)
return QuickTranscriptResponse(
transcript=transcript,
has_speech=has_speech,
message="OK" if has_speech else "No speech detected"
)
except Exception as exc:
print(f"[transcribe-error] {exc}")
return QuickTranscriptResponse(
transcript="",
has_speech=False,
message="Transcription failed"
)
@app.post("/sessions/{session_id}/audio-chunk", response_model=ChunkAnalysisResponse)
async def upload_audio_chunk(session_id: str, file: UploadFile = File(...)):
"""Full analysis endpoint - use when recording stops.
Uses a local Whisper model (via faster-whisper) so there is
no dependency on paid cloud APIs. The audio comes from the
browser as WEBM/Opus; we write it to a temporary file and let
Whisper handle decoding via ffmpeg.
"""
audio_bytes = await file.read()
async def recognize_with_whisper(audio_content: bytes) -> tuple[str, float | None, list]:
"""Run Whisper transcription in a worker thread.
Returns a pair of (transcript, duration_seconds, segments).
"""
def _call() -> tuple[str, float | None, list]:
# Use global model instance
tmp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
try:
tmp_file.write(audio_content)
tmp_file.flush()
tmp_file.close()
# Preprocess: high-pass filter + normalize
audio_path = preprocess_audio(tmp_file.name)
segments, info = model.transcribe(
audio_path,
language="tl", # Force Tagalog to prevent translation to English
task="transcribe", # Transcribe, don't translate to English
beam_size=5, # Better accuracy
word_timestamps=True, # Better alignment
initial_prompt=FILIPINO_VOCAB_PROMPT, # Filipino vocab hint
vad_filter=False, # Disabled to avoid cutting off speech
condition_on_previous_text=False, # Faster, no context dependency
)
segment_list = list(segments)
texts: list[str] = []
for segment in segment_list:
if segment.text:
texts.append(segment.text.strip())
transcript_text = " ".join(texts).strip()
# Post-process: fix known misrecognitions
transcript_text = post_process_transcript(transcript_text)
duration_seconds: float | None = None
# Prefer model-reported duration when available.
if getattr(info, "duration", None):
duration_seconds = float(info.duration) # type: ignore[arg-type]
elif segment_list:
start = float(segment_list[0].start or 0.0)
end = float(segment_list[-1].end or 0.0)
if end > start:
duration_seconds = end - start
return transcript_text, duration_seconds, segment_list
finally:
try:
os.remove(tmp_file.name)
except OSError:
pass
return await asyncio.to_thread(_call)
transcript = ""
duration_seconds: float | None = None
segments: list = []
try:
transcript, duration_seconds, segments = await recognize_with_whisper(audio_bytes)
if transcript:
message = "Transcription successful."
else:
message = "No clear speech detected in this chunk."
except Exception as exc: # pragma: no cover - defensive for runtime issues
# Log detailed error on the server side only.
print(f"[whisper-error] Failed to transcribe chunk for session {session_id}: {exc}")
message = "Transcription skipped for this chunk (audio too short or invalid)."
transcript = ""
# Run analysis modules
# Use fallback duration of 3.0s if undefined, to avoid division by zero
safe_duration = duration_seconds if duration_seconds and duration_seconds > 0 else 3.0
fillers = detect_fillers(transcript)
pace = calculate_pace(transcript, safe_duration)
prosody = analyze_prosody(segments, safe_duration)
# Use RoBERTa for advanced fluency scoring (or fallback to heuristic)
coherence = await get_fluency_score(transcript)
feedback = generate_feedback(pace, fillers, prosody, coherence)
return ChunkAnalysisResponse(
transcript=transcript,
wpm=pace.wpm,
filler_count=fillers.count,
fillers=fillers,
pacing=pace,
prosody=prosody,
coherence_score=coherence,
feedback=feedback,
message=message,
)
if __name__ == "__main__":
import uvicorn
# Run the FastAPI app via uvicorn directly from python
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)