jatinsabari committed on
Commit
294fc5b
·
verified ·
1 Parent(s): 62ccb65

Create app.py

Files changed (1)
  1. app.py +291 -0
app.py ADDED
import gradio as gr
import librosa
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from huggingface_hub import login
import tempfile
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware

# === CONFIGURATION ===
HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
MODEL_NAME = "google/gemma-2b-it"

# Login to Hugging Face
try:
    if HF_TOKEN and HF_TOKEN != "your_hf_token_here":
        login(token=HF_TOKEN)
        print("✅ Authenticated with Hugging Face Hub")
    else:
        print("⚠️ No HF_TOKEN provided, using fallback method")
except Exception as e:
    print(f"⚠️ Authentication warning: {e}")

class GemmaAudioEmotionAnalyzer:
    def __init__(self, model_name: str = MODEL_NAME):
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🚀 Using device: {self.device}")

        try:
            print("📥 Loading Gemma tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=HF_TOKEN if HF_TOKEN != "your_hf_token_here" else None,
                trust_remote_code=True
            )

            print("📥 Loading Gemma model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=HF_TOKEN if HF_TOKEN != "your_hf_token_here" else None,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None,
                trust_remote_code=True
            )

            # Gemma ships without a pad token; reuse EOS so padded calls work
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print("✅ Gemma model loaded successfully!")

        except Exception as e:
            print(f"❌ Failed to load Gemma: {e}")
            print("🔧 Using fallback rule-based analyzer")
            self.model = None
            self.tokenizer = None

    def extract_fast_features(self, audio_path: str) -> dict:
        """Extract minimal features quickly"""
        try:
            y, sr = librosa.load(audio_path, sr=16000, duration=3)

            # Run piptrack once (it is expensive) and guard against clips with
            # no voiced frames, where np.median of an empty array would be NaN
            pitches, _ = librosa.piptrack(y=y, sr=sr)
            voiced = pitches[pitches > 0]
            pitch = float(np.median(voiced)) if voiced.size > 0 else 150.0

            features = {
                'energy': float(np.mean(librosa.feature.rms(y=y))),
                'brightness': float(np.mean(librosa.feature.spectral_centroid(y=y, sr=sr))),
                'pitch': pitch,
                'tempo': float(librosa.beat.tempo(y=y, sr=sr)[0]),
                'speech_rate': float(np.mean(librosa.feature.zero_crossing_rate(y)))
            }
            return features
        except Exception as e:
            print(f"❌ Feature extraction error: {e}")
            return {'energy': 0.05, 'brightness': 1500, 'pitch': 200, 'tempo': 100, 'speech_rate': 0.1}

    def create_gemma_prompt(self, features: dict) -> str:
        """Create optimized prompt for Gemma"""
        prompt = f"""Analyze the emotional content from these audio features:

Audio Characteristics:
- Energy Level: {"High" if features['energy'] > 0.08 else "Low" if features['energy'] < 0.03 else "Medium"}
- Brightness: {"Bright" if features['brightness'] > 2000 else "Dark" if features['brightness'] < 1000 else "Neutral"}
- Average Pitch: {"High" if features['pitch'] > 250 else "Low" if features['pitch'] < 150 else "Medium"}
- Tempo: {"Fast" if features['tempo'] > 140 else "Slow" if features['tempo'] < 90 else "Moderate"}
- Speech Rate: {"Rapid" if features['speech_rate'] > 0.15 else "Slow" if features['speech_rate'] < 0.08 else "Normal"}

Based on these acoustic properties, identify the primary emotion. Choose ONE from: happy, sad, angry, fearful, neutral, excited, calm.

Respond in this exact format:
Emotion: [emotion]
Confidence: [high/medium/low]
Reason: [brief reason based on features]

Analysis:"""
        return prompt

    def generate_with_gemma(self, prompt: str) -> str:
        """Generate response using Gemma with optimized settings"""
        if self.model is None:
            return "Emotion: neutral\nConfidence: medium\nReason: Using fallback analysis"

        try:
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=100,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1
                )

            # Decode the full sequence, then drop the echoed prompt
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return response[len(prompt):].strip()

        except Exception as e:
            print(f"❌ Gemma generation error: {e}")
            return "Emotion: neutral\nConfidence: low\nReason: Analysis unavailable"

    def parse_gemma_response(self, response: str) -> dict:
        """Parse Gemma's response"""
        lines = response.split('\n')
        result = {
            'emotion': 'neutral',
            'confidence': 'medium',
            'reason': 'No analysis provided',
            'raw_response': response
        }

        for line in lines:
            line = line.strip()
            if line.startswith('Emotion:'):
                result['emotion'] = line.split(':', 1)[1].strip().lower()
            elif line.startswith('Confidence:'):
                result['confidence'] = line.split(':', 1)[1].strip().lower()
            elif line.startswith('Reason:'):
                result['reason'] = line.split(':', 1)[1].strip()

        return result

    def analyze_emotion(self, audio_path: str) -> dict:
        """Main analysis function"""
        print(f"🎵 Analyzing: {os.path.basename(audio_path)}")

        features = self.extract_fast_features(audio_path)
        prompt = self.create_gemma_prompt(features)

        print("🤖 Querying Gemma...")
        gemma_response = self.generate_with_gemma(prompt)

        result = self.parse_gemma_response(gemma_response)
        result['features'] = features

        print(f"✅ Gemma result: {result['emotion']}")
        return result

# Initialize analyzer
print("🔄 Initializing Gemma Audio Analyzer...")
analyzer = GemmaAudioEmotionAnalyzer()

def process_audio(audio_path: str) -> str:
    """Gradio interface function"""
    if not audio_path:
        return "❌ Please provide an audio file"

    try:
        result = analyzer.analyze_emotion(audio_path)

        emotion_icons = {
            'happy': '😊', 'sad': '😒', 'angry': '😠',
            'fearful': '😨', 'neutral': '😐', 'excited': '🤩', 'calm': '😌'
        }

        icon = emotion_icons.get(result['emotion'], '🎭')

        output = f"""
{icon} **Emotion**: {result['emotion'].title()}
📊 **Confidence**: {result['confidence'].title()}
💭 **Reason**: {result['reason']}

🔬 **Audio Analysis**:
• Energy: {result['features']['energy']:.3f}
• Brightness: {result['features']['brightness']:.0f} Hz
• Pitch: {result['features']['pitch']:.0f} Hz
• Tempo: {result['features']['tempo']:.0f} BPM

🤖 **Powered by Google Gemma**
"""
        return output

    except Exception as e:
        return f"❌ Error: {str(e)}"

# ============ NEW: FastAPI Integration ============
app = FastAPI(title="Echo Emotion Detection API")

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Routes are matched in registration order, so a GET "/" defined here would
# shadow the Gradio UI mounted at "/" below; serve the API info at /api instead
@app.get("/api")
async def root():
    """API Info"""
    return {
        "service": "Echo Emotion Detection API",
        "status": "online",
        "version": "1.0.0",
        "endpoints": {
            "analyze": "POST /api/analyze",
            "health": "GET /health"
        }
    }

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "model_loaded": analyzer.model is not None
    }

@app.post("/api/analyze")
async def api_analyze(audio: UploadFile = File(...)):
    """
    API endpoint for emotion detection

    Example usage:
    curl -X POST "https://your-space.hf.space/api/analyze" \
         -F "audio=@voice.mp3"
    """
    try:
        # Save uploaded file temporarily (filename may be missing, so default the suffix)
        suffix = os.path.splitext(audio.filename or "audio.wav")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            content = await audio.read()
            tmp_file.write(content)
            tmp_path = tmp_file.name

        try:
            # Analyze emotion using the shared analyzer
            result = analyzer.analyze_emotion(tmp_path)
        finally:
            # Clean up the temp file even if analysis fails
            os.unlink(tmp_path)

        # Return structured JSON response
        return {
            "success": True,
            "emotion": result['emotion'],
            "confidence": result['confidence'],
            "reason": result['reason'],
            "features": result['features']
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")

# Create Gradio interface
demo = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(
        sources=["upload"],
        type="filepath",
        label="Upload Audio File",
        max_length=10  # maximum clip length, in seconds
    ),
    outputs=gr.Markdown(label="Gemma Emotion Analysis"),
    title="🎵 Audio Emotion Analysis with Google Gemma",
    description="Upload audio to analyze emotions using Google's Gemma model",
    examples=[],
    allow_flagging="never"
)

# Mount Gradio to FastAPI at root path
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    print("🚀 Starting Echo API Server...")
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
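
A minimal client sketch for the API defined above, assuming the `requests` library is installed; the Space URL and the `sample.wav` path are placeholders. The multipart field name `audio` must match the `UploadFile` parameter of `/api/analyze`.

import requests

BASE_URL = "https://your-space.hf.space"  # placeholder; use your deployed Space URL

# Optional: confirm the model loaded before sending audio
health = requests.get(f"{BASE_URL}/health", timeout=10).json()
print("Model loaded:", health["model_loaded"])

# POST a local file as multipart form data; the field name "audio"
# matches the endpoint's `audio: UploadFile = File(...)` parameter
with open("sample.wav", "rb") as f:  # placeholder path
    resp = requests.post(f"{BASE_URL}/api/analyze", files={"audio": f}, timeout=120)
resp.raise_for_status()

result = resp.json()
print(result["emotion"], result["confidence"], "-", result["reason"])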