# KARTE / main.py
# Last commit 5896fe9 by shuhayas: "Merge changes from hf-deployment branch"
import hashlib
import hmac
import json
import os
import sys
import tempfile
import time

import groq
import openai
import streamlit as st
from dotenv import load_dotenv
from pydub import AudioSegment
# Debug information.
# Startup diagnostics are written to the server's stderr, not the page. This
# code runs before the login check, so st.write() here would show every
# visitor the directory listing and the names of all environment variables.
# Keeping this block free of Streamlit calls also leaves st.set_page_config()
# as the script's first Streamlit command, which older Streamlit releases
# require.
print(f"Python version: {sys.version}", file=sys.stderr)
print(f"Current working directory: {os.getcwd()}", file=sys.stderr)
print(f"Files in current directory: {os.listdir('.')}", file=sys.stderr)
print(
    f"Environment variables: {[k for k in os.environ.keys() if not k.startswith('_')]}",
    file=sys.stderr,
)

# Load environment variables from a .env file, if present.
load_dotenv()
print("Loaded .env file (if it exists)", file=sys.stderr)
# Function for basic authentication
def check_password():
    """Show a login form until the current session has authenticated.

    Expected credentials come from the BASIC_AUTH_USERNAME and
    BASIC_AUTH_PASSWORD environment variables. When these are unset they fall
    back to "admin" / "password".

    NOTE(review): that fallback is insecure; deployments should always set
    both variables.

    On a successful login, the flag is stored in ``st.session_state`` and the
    script is re-run so the form disappears.

    Returns:
        True if this session is already authenticated. Otherwise False, and
        the caller is expected to stop rendering.
    """
    username = os.getenv("BASIC_AUTH_USERNAME", "admin")
    password = os.getenv("BASIC_AUTH_PASSWORD", "password")

    if st.session_state.get("authenticated"):
        return True

    auth_username = st.text_input("Username")
    auth_password = st.text_input("Password", type="password")
    if st.button("Login"):
        # Constant-time comparison avoids leaking how much of a credential
        # matched. compare_digest only accepts ASCII str, so compare UTF-8
        # bytes. Both comparisons always run.
        user_ok = hmac.compare_digest(auth_username.encode("utf-8"), username.encode("utf-8"))
        pass_ok = hmac.compare_digest(auth_password.encode("utf-8"), password.encode("utf-8"))
        if user_ok and pass_ok:
            st.session_state["authenticated"] = True
            # st.experimental_rerun was replaced by st.rerun and is absent from
            # current Streamlit releases; use whichever this install provides.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()
        else:
            st.error("Invalid credentials")
    return False
# Page config: browser-tab title and icon, wide layout, sidebar expanded on load.
_PAGE_SETTINGS = {
    "page_title": "KARTE - Audio Analysis",
    "page_icon": "🎯",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)

# Custom theme: light-grey background on `.main`, red buttons with white text.
# The CSS is injected as raw HTML, which is why unsafe_allow_html is required.
_THEME_CSS = """
<style>
.main {
background-color: #f5f5f5;
}
.stButton>button {
background-color: #ff4b4b;
color: white;
}
</style>
"""
st.markdown(_THEME_CSS, unsafe_allow_html=True)
# Check authentication. Nothing below the login form is rendered until the
# session is authenticated; st.stop() ends the current script run.
is_authenticated = check_password()
if not is_authenticated:
    st.stop()
# Initialize API clients
def _lookup_api_key(name):
    """Return the API key called *name*, or None if it is not configured.

    The process environment (which includes values loaded from .env) takes
    precedence. If the variable is unset or empty, Streamlit secrets are
    consulted. Each key is resolved on its own, so a key present in the
    environment is never replaced by a missing secret.
    """
    value = os.getenv(name)
    if value:
        return value
    try:
        return st.secrets.get(name)
    except Exception as secrets_error:
        # Accessing st.secrets can raise (e.g. when no secrets file is
        # configured) instead of returning None. Report it, treat as missing.
        st.error(f"Error accessing Streamlit secrets: {str(secrets_error)}")
        return None


try:
    openai_api_key = _lookup_api_key("OPENAI_API_KEY")
    groq_api_key = _lookup_api_key("GROQ_API_KEY")

    # Debug information about API keys (presence only, never the values)
    st.write(f"OpenAI API key found: {bool(openai_api_key)}")
    st.write(f"Groq API key found: {bool(groq_api_key)}")

    if not openai_api_key:
        st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
        st.error("OpenAI APIキーが設定されていません。OPENAI_API_KEY環境変数を設定してください。")
        st.stop()
    if not groq_api_key:
        st.error("Groq API key is not set. Please set the GROQ_API_KEY environment variable.")
        st.error("Groq APIキーが設定されていません。GROQ_API_KEY環境変数を設定してください。")
        st.stop()

    # Initialize clients
    st.write("Initializing API clients...")
    openai_client = openai.OpenAI(api_key=openai_api_key)
    groq_client = groq.Groq(api_key=groq_api_key)
    st.write("API clients initialized successfully")
except Exception as e:
    st.error(f"Error initializing API clients: {str(e)}")
    st.error(f"Exception type: {type(e).__name__}")
    # sys.exc_info() only shows a (type, value, traceback-object) tuple;
    # st.exception renders the actual stack trace.
    st.exception(e)
    st.stop()
# Title and description shown at the top of the page.
_HEADER_TITLE = "🎯 KARTE - Audio Analysis"
_HEADER_BLURB = "Upload an audio file for analysis and medical record generation."
st.title(_HEADER_TITLE)
st.markdown(_HEADER_BLURB)

# Two tabs: the upload/analysis workflow and the prompt-template editor.
analysis_tab, settings_tab = st.tabs(["Analysis Execution", "Prompt Settings"])
tab1, tab2 = analysis_tab, settings_tab
# Whisper's transcription endpoint rejects uploads larger than 25 MB. Chunks
# are re-encoded as 16 kHz mono 16-bit PCM (32,000 bytes/s), so a 10-minute
# chunk is about 19.2 MB of WAV data and stays under that limit.
_CHUNK_SIZE_MS = 10 * 60 * 1000  # 10 minutes
_WHISPER_FRAME_RATE = 16000

# Fallback prompt templates, used when prompts/templates.json cannot be read.
_DEFAULT_TEMPLATES = {
    "style_template": """以下の会話文から、接客スタイルを分析してください:
- 言葉遣い(丁寧さ、適切性)
- 対応の質(共感性、解決力)
- 改善点
会話文:
{text}""",
    "flow_template": """以下の会話文から、対応フローを分析してください:
- 導入(挨拶、用件確認)
- 展開(問題把握、解決提案)
- 結論(まとめ、次のアクション)
会話文:
{text}""",
    "medical_template": """以下の会話文から、診療記録を作成してください:
- 主訴
- 現病歴
- 診察所見
- 検査結果
- 診断
- 治療計画
会話文:
{text}""",
}


def _load_prompt_templates():
    """Return templates from prompts/templates.json, or the built-in defaults.

    A missing, unreadable, or malformed file falls back to the defaults.
    """
    try:
        with open("prompts/templates.json", "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return dict(_DEFAULT_TEMPLATES)


def _fill_template(template, transcription):
    """Substitute *transcription* into the template's ``{text}`` placeholder.

    str.format is tried first, so templates relying on ``{{``/``}}`` escapes
    behave as before. A user-edited template containing other literal braces
    (e.g. a JSON example) makes format() raise; in that case fall back to a
    plain text replacement.
    """
    try:
        return template.format(text=transcription)
    except (KeyError, IndexError, AttributeError, ValueError):
        return template.replace("{text}", transcription)


def _transcribe_audio(audio):
    """Transcribe a pydub AudioSegment with Whisper (language "ja"), chunk by chunk.

    Returns the chunk texts concatenated, each followed by a single space.
    """
    duration_ms = len(audio)
    pieces = []
    # One scratch directory holds every chunk. It is removed when this block
    # exits, even if an API call raises.
    with tempfile.TemporaryDirectory() as chunk_dir:
        for index, start in enumerate(range(0, duration_ms, _CHUNK_SIZE_MS)):
            chunk = audio[start:min(start + _CHUNK_SIZE_MS, duration_ms)]
            chunk = (
                chunk.set_frame_rate(_WHISPER_FRAME_RATE)
                .set_channels(1)
                .set_sample_width(2)
            )
            chunk_path = os.path.join(chunk_dir, f"chunk_{index}.wav")
            # export() returns the file object it wrote to; close it so the
            # handle is not leaked (and the directory can be removed on Windows).
            chunk.export(chunk_path, format="wav").close()
            with open(chunk_path, "rb") as audio_file:
                result = openai_client.audio.transcriptions.create(
                    model="whisper-1",
                    file=audio_file,
                    language="ja",
                )
            pieces.append(result.text + " ")
    return "".join(pieces)


with tab1:
    # File uploader
    uploaded_file = st.file_uploader("Upload an audio file", type=["mp3", "wav", "m4a"])
    if uploaded_file:
        audio_path = None
        try:
            audio_bytes = uploaded_file.getvalue()
            # Streamlit re-runs this whole script on every interaction, such as
            # clicking "Generate Analysis" or editing a prompt in the other tab.
            # Cache the transcription per file content so Whisper is called
            # only once per upload.
            audio_digest = hashlib.sha256(audio_bytes).hexdigest()
            cached = st.session_state.get("transcription_cache")
            if cached is not None and cached[0] == audio_digest:
                transcription = cached[1]
            else:
                # Save uploaded file so pydub/ffmpeg can read it from disk.
                suffix = os.path.splitext(uploaded_file.name)[1]
                with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                    tmp_file.write(audio_bytes)
                    audio_path = tmp_file.name

                # Load audio file with diagnostic logging
                st.write(f"Loading audio file from: {audio_path}")
                st.write(f"File exists: {os.path.exists(audio_path)}")
                st.write(f"File size: {os.path.getsize(audio_path)} bytes")
                try:
                    audio = AudioSegment.from_file(audio_path)
                    st.write(f"Audio loaded successfully. Format: {audio.frame_rate}Hz, Channels: {audio.channels}")
                    duration_ms = len(audio)
                    st.write(f"Audio duration: {duration_ms/1000} seconds")
                except Exception as audio_error:
                    st.error(f"Error loading audio: {str(audio_error)}")
                    st.error("This could be due to missing ffmpeg or an unsupported audio format")
                    st.stop()

                with st.spinner("Transcribing audio..."):
                    transcription = _transcribe_audio(audio)
                st.session_state["transcription_cache"] = (audio_digest, transcription)

            st.success("Transcription completed!")
            st.text_area("Transcription", transcription, height=200)

            if st.button("Generate Analysis"):
                with st.spinner("Analyzing..."):
                    templates = _load_prompt_templates()

                    # Generate one analysis per template.
                    analyses = {}
                    for analysis_type, prompt_template in templates.items():
                        response = groq_client.chat.completions.create(
                            # NOTE(review): confirm this model id is still served by Groq.
                            model="mixtral-8x7b-32768",
                            messages=[{
                                "role": "user",
                                "content": _fill_template(prompt_template, transcription),
                            }],
                        )
                        analyses[analysis_type] = response.choices[0].message.content

                    # Display results
                    for title, content in analyses.items():
                        st.subheader(title.replace("_template", "").title())
                        st.write(content)

                    # Combine all analyses into one downloadable report.
                    combined_analysis = {
                        "transcription": transcription,
                        **analyses,
                    }
                    st.download_button(
                        "Download Analysis Report",
                        data=json.dumps(combined_analysis, ensure_ascii=False, indent=2),
                        file_name="analysis_report.json",
                        mime="application/json",
                    )
        except Exception as e:
            st.error(f"Error processing audio file: {str(e)}")
        finally:
            # Cleanup the uploaded-file copy (chunk files clean themselves up).
            if audio_path is not None:
                os.unlink(audio_path)
# Built-in prompt templates; the editor starts from these until templates are saved.
_EDITOR_DEFAULT_TEMPLATES = {
    "style_template": """以下の会話文から、接客スタイルを分析してください:
- 言葉遣い(丁寧さ、適切性)
- 対応の質(共感性、解決力)
- 改善点
会話文:
{text}""",
    "flow_template": """以下の会話文から、対応フローを分析してください:
- 導入(挨拶、用件確認)
- 展開(問題把握、解決提案)
- 結論(まとめ、次のアクション)
会話文:
{text}""",
    "medical_template": """以下の会話文から、診療記録を作成してください:
- 主訴
- 現病歴
- 診察所見
- 検査結果
- 診断
- 治療計画
会話文:
{text}""",
}


def _templates_for_editor():
    """Return the templates to pre-fill the editor with.

    Saved values from prompts/templates.json override the built-in defaults,
    so a new session shows (and re-saves) the user's customizations rather
    than silently reverting them. Unknown keys and non-string values are
    ignored; an unreadable file yields the defaults.
    """
    templates = dict(_EDITOR_DEFAULT_TEMPLATES)
    try:
        with open("prompts/templates.json", "r", encoding="utf-8") as f:
            saved = json.load(f)
    except (OSError, ValueError):
        return templates
    if isinstance(saved, dict):
        templates.update(
            {key: value for key, value in saved.items() if key in templates and isinstance(value, str)}
        )
    return templates


with tab2:
    # Prompt settings
    st.subheader("Prompt Templates")
    initial_templates = _templates_for_editor()

    # Style analysis prompt
    style_template = st.text_area(
        "Style Analysis Prompt",
        initial_templates["style_template"],
        height=200,
    )
    # Flow analysis prompt
    flow_template = st.text_area(
        "Flow Analysis Prompt",
        initial_templates["flow_template"],
        height=200,
    )
    # Medical record prompt
    medical_template = st.text_area(
        "Medical Record Prompt",
        initial_templates["medical_template"],
        height=200,
    )

    if st.button("Save Templates"):
        templates = {
            "style_template": style_template,
            "flow_template": flow_template,
            "medical_template": medical_template,
        }
        try:
            # prompts/ may not exist yet (e.g. on a fresh deployment).
            os.makedirs("prompts", exist_ok=True)
            with open("prompts/templates.json", "w", encoding="utf-8") as f:
                json.dump(templates, f, ensure_ascii=False, indent=2)
        except OSError as save_error:
            st.error(f"Error saving templates: {str(save_error)}")
        else:
            st.success("Templates saved successfully!")