Adedoyinjames committed on
Commit
9b08f3f
·
verified ·
1 Parent(s): 0f6dcc4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +377 -0
app.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from fastapi import FastAPI, HTTPException
4
+ from fastapi.responses import FileResponse, JSONResponse
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from pydantic import BaseModel
7
+ from PIL import Image
8
+ import base64
9
+ import io
10
+ import json
11
+ import os
12
+ from pathlib import Path
13
+ import tempfile
14
+ import uvicorn
15
+
16
+ # ============================================
17
+ # IMPORTS FOR MODELS
18
+ # ============================================
19
+ from transformers import (
20
+ CLIPProcessor, CLIPModel,
21
+ AutoTokenizer, AutoModelForCausalLM,
22
+ pipeline
23
+ )
24
+ from TTS.api import TTS
25
+
26
+ # ============================================
27
+ # CONFIGURATION FOR CPU
28
+ # ============================================
29
DEVICE = "cpu"
TORCH_DTYPE = torch.float32

# Model names. These are the stock Hub checkpoints loaded in float32 below;
# nothing in this file applies quantization.
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"
LLM_MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
TTS_MODEL_NAME = "tts_models/en/ljspeech/glow-tts"

# ============================================
# INITIALIZE MODELS (Global, loaded once)
# ============================================
# Everything below runs at import time, so the process only starts serving
# once all three models have been loaded.
print("[INFO] Loading CLIP model...")
# No `device_map` here: the explicit `.to(DEVICE)` already places the model,
# and passing both is redundant.
clip_model = CLIPModel.from_pretrained(
    CLIP_MODEL_NAME,
    torch_dtype=TORCH_DTYPE,
).to(DEVICE).eval()
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)

print("[INFO] Loading LLM (Qwen2.5-1.5B)...")
# `trust_remote_code` is intentionally not enabled: the Qwen2 architecture is
# implemented inside transformers, so there is no reason to allow code from
# the Hub repository to execute in this process.
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
llm_model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME,
    torch_dtype=TORCH_DTYPE,
    low_cpu_mem_usage=True,
).to(DEVICE).eval()

print("[INFO] Loading TTS model (Glow-TTS)...")
tts = TTS(model_name=TTS_MODEL_NAME, gpu=False, progress_bar=False, verbose=False)
63
+
64
+ # ============================================
65
+ # FAST API APP
66
+ # ============================================
67
app = FastAPI(title="Coder Tutor Backend", version="1.0")

# CORS for frontend communication. Origins, methods and headers stay
# wildcarded so any frontend can call the API. Credentials are disabled
# because wildcard origins must not be combined with credentialed requests,
# and none of the endpoints in this file read cookies or auth headers.
# NOTE(review): if a frontend really needs cookies, list its origins
# explicitly instead of re-enabling credentials with "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
77
+
78
+ # ============================================
79
+ # PYDANTIC MODELS
80
+ # ============================================
81
class LearningRequest(BaseModel):
    """Request body accepted by ``POST /learn``."""

    # Base64 encoded image of the learner's screen.
    screenshot_base64: str
    # The learner's question or problem statement.
    user_query: str
    # Earlier turns of the conversation. The guidance generator reads each
    # entry with ``.get("role")`` plus ``"query"`` (user turns) or
    # ``"guidance"`` (assistant turns); the element type is not validated.
    conversation_history: list = []
    # TTS speed multiplier (0.5-2.0). The range is not enforced by the model.
    speech_speed: float = 1.0
86
+
87
class LearningResponse(BaseModel):
    """Response body returned by ``POST /learn``."""

    # Tutor explanation produced by the LLM.
    guidance: str
    # Spoken version of ``guidance``; the /learn handler fills this with a
    # ``data:audio/wav;base64,...`` URL rather than a link to a hosted file.
    audio_url: str
    # Probability CLIP assigned to the detected screen context.
    confidence: float
91
+
92
+ # ============================================
93
+ # HELPER FUNCTIONS
94
+ # ============================================
95
def decode_image(image_base64: str) -> Image.Image:
    """Decode a base64 image string into a 224x224 RGB PIL image.

    Accepts either raw base64 or a data URL such as
    ``data:image/png;base64,<payload>``.

    Args:
        image_base64: The encoded screenshot.

    Returns:
        The decoded image, converted to RGB and resized to 224x224.

    Raises:
        HTTPException: 400 if the payload cannot be decoded or opened as an
            image.
    """
    try:
        payload = image_base64.strip()
        # Strip a data-URL header if present. Neither ':' nor ',' belongs to
        # the base64 alphabet, so plain base64 input is never affected.
        if payload.startswith("data:"):
            payload = payload.partition(",")[2]

        image_data = base64.b64decode(payload)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        # Resize for faster processing (CLIP works well with 224x224)
        return image.resize((224, 224), Image.Resampling.LANCZOS)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {str(e)}") from e
105
+
106
def analyze_screenshot_with_clip(image: Image.Image) -> dict:
    """Use CLIP to classify what kind of screen the screenshot shows.

    The image is compared against a fixed list of text labels using CLIP's
    zero-shot formulation: both embeddings are L2-normalised, their cosine
    similarity is multiplied by the model's learned ``logit_scale``, and a
    softmax over the labels yields the probabilities.

    Args:
        image: The screenshot as a PIL image.

    Returns:
        A dict with ``detected_context`` (the best label), ``confidence``
        (its probability) and ``all_probs`` (label -> probability).
    """
    # Candidate descriptions of what might be on screen.
    labels = [
        "Python code editor",
        "JavaScript code",
        "HTML/CSS markup",
        "Terminal/console output",
        "Error message",
        "Browser DevTools",
        "IDE or text editor",
        "File explorer",
        "Command line",
        "Documentation page",
    ]

    with torch.no_grad():
        # `padding` is a tokenizer option, so it is only passed for the text.
        image_inputs = clip_processor(
            images=image,
            return_tensors="pt",
        ).to(DEVICE)
        text_inputs = clip_processor(
            text=labels,
            return_tensors="pt",
            padding=True,
        ).to(DEVICE)

        image_features = clip_model.get_image_features(**image_inputs)
        text_features = clip_model.get_text_features(**text_inputs)

        # Raw projected embeddings are not unit length; without normalising
        # and applying the temperature, the softmax is dominated by embedding
        # norms instead of image/text agreement.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = clip_model.logit_scale.exp()
        logits_per_image = logit_scale * (image_features @ text_features.t())
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0]

    top_idx = int(np.argmax(probs))

    return {
        "detected_context": labels[top_idx],
        "confidence": float(probs[top_idx]),
        "all_probs": {label: float(prob) for label, prob in zip(labels, probs)},
    }
153
+
154
# Prompt template for beginner-friendly explanations; `{context}` and
# `{query}` are filled in per request.
_TUTOR_PROMPT_TEMPLATE = """You are an expert coding tutor teaching beginners. Your rules:

1. **Explain like they've never coded before** - define every term
2. **Use analogies** - relate coding concepts to real-world things
3. **Break it down** - never give full solutions, only next small step
4. **Ask questions** - encourage thinking, don't just tell
5. **Be encouraging** - celebrate small wins
6. **Use simple language** - avoid jargon without explanation
7. **Give code examples** - when relevant, show concrete examples

Current screen context: {context}
User's question/problem: {query}

Provide a step-by-step explanation of what they should do next. Keep it to 2-3 short paragraphs maximum."""


def _format_history(conversation_history: list, max_messages: int = 4) -> str:
    """Render the last ``max_messages`` history entries as a plain transcript."""
    lines = []
    for msg in (conversation_history or [])[-max_messages:]:
        # History comes straight from the request body and is not validated,
        # so skip anything that is not a mapping instead of crashing on .get().
        if not isinstance(msg, dict):
            continue
        role = msg.get("role")
        if role == "user":
            lines.append(f"User: {msg.get('query', '')}\n")
        elif role == "assistant":
            lines.append(f"Assistant: {msg.get('guidance', '')}\n")
    return "".join(lines)


def generate_beginner_guidance(
    user_query: str,
    screen_context: str,
    conversation_history: list
) -> str:
    """Generate a beginner-friendly explanation with the LLM.

    Args:
        user_query: The learner's question.
        screen_context: Label describing what is on screen.
        conversation_history: Earlier turns; only the last four are used.
            User entries contribute their ``query`` field, assistant entries
            their ``guidance`` field. Malformed entries are ignored.

    Returns:
        The generated guidance with surrounding whitespace stripped.
    """
    prompt = _TUTOR_PROMPT_TEMPLATE.format(context=screen_context, query=user_query)

    # Add history if available
    history_text = _format_history(conversation_history)
    if history_text:
        prompt += f"\n\nPrevious conversation:\n{history_text}"

    # Generate with Qwen
    messages = [{"role": "user", "content": prompt}]

    with torch.no_grad():
        text = llm_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # A single sequence needs no padding.
        model_inputs = llm_tokenizer(text, return_tensors="pt").to(DEVICE)

        generated_ids = llm_model.generate(
            **model_inputs,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=llm_tokenizer.eos_token_id
        )

        # generate() returns prompt + completion; keep only the new tokens.
        prompt_length = model_inputs.input_ids.shape[1]
        response = llm_tokenizer.decode(
            generated_ids[0][prompt_length:],
            skip_special_tokens=True
        )

    return response.strip()
222
+
223
def generate_speech(text: str, speed: float = 1.0) -> str:
    """Synthesise ``text`` to a WAV file with Coqui TTS and return its path.

    The audio is always written to ``guidance_speech.wav`` inside the system
    temp directory, so each call overwrites the previous result.

    Args:
        text: The text to speak.
        speed: Accepted for API compatibility; it is not used by this
            function, so the audio is produced at the model's default speed.

    Returns:
        Path of the generated WAV file.

    Raises:
        HTTPException: 500 if synthesis fails for any reason.
    """
    try:
        output_path = os.path.join(tempfile.gettempdir(), "guidance_speech.wav")

        # Use the first available speaker if the model exposes any.
        available_speakers = tts.speakers
        if available_speakers:
            chosen_speaker = available_speakers[0]
        else:
            chosen_speaker = None

        tts.tts_to_file(
            text=text,
            file_path=output_path,
            speaker=chosen_speaker,
        )
        return output_path
    except Exception as exc:
        print(f"[ERROR] TTS generation failed: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"TTS failed: {str(exc)}")
242
+
243
+ # ============================================
244
+ # API ENDPOINTS
245
+ # ============================================
246
+
247
@app.post("/learn", response_model=LearningResponse)
async def learn(request: LearningRequest):
    """Turn a screenshot plus a question into written and spoken guidance.

    Pipeline: decode the screenshot, classify it with CLIP, generate guidance
    with the LLM, synthesise speech, and return the audio inline as a
    base64 ``data:`` URL. ``HTTPException`` raised by any step is passed
    through; any other error is reported as a 500.
    """
    try:
        # Step 1: screenshot -> image -> screen classification
        print("[INFO] Decoding screenshot...")
        screenshot = decode_image(request.screenshot_base64)

        print("[INFO] Analyzing screen with CLIP...")
        clip_result = analyze_screenshot_with_clip(screenshot)
        detected_context = clip_result["detected_context"]

        # Step 2: written guidance
        print("[INFO] Generating guidance with LLM...")
        guidance_text = generate_beginner_guidance(
            user_query=request.user_query,
            screen_context=detected_context,
            conversation_history=request.conversation_history,
        )

        # Step 3: spoken guidance
        print("[INFO] Generating speech...")
        wav_path = generate_speech(guidance_text, speed=request.speech_speed)

        # Step 4: inline the audio so the client needs no second request
        with open(wav_path, "rb") as wav_file:
            encoded_audio = base64.b64encode(wav_file.read()).decode()

        return LearningResponse(
            guidance=guidance_text,
            audio_url="data:audio/wav;base64," + encoded_audio,
            confidence=clip_result["confidence"],
        )
    except HTTPException:
        raise
    except Exception as exc:
        print(f"[ERROR] {str(exc)}")
        raise HTTPException(status_code=500, detail=str(exc))
288
+
289
class AnalyzeRequest(BaseModel):
    """Request body for ``POST /analyze-screenshot``."""

    # Base64 encoded screenshot.
    screenshot_base64: str


@app.post("/analyze-screenshot")
async def analyze_screenshot(request: AnalyzeRequest):
    """
    Endpoint to just analyze what's on screen without generating guidance.
    Useful for debugging or understanding context.

    The body must be declared with a concrete model: a bare ``BaseModel``
    parameter has no fields, so ``request.screenshot_base64`` would always
    raise and the endpoint would always answer 500.
    """
    try:
        image = decode_image(request.screenshot_base64)
        analysis = analyze_screenshot_with_clip(image)
    except HTTPException:
        # Keep the 400 raised for an invalid image instead of masking it.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return JSONResponse({
        "detected_context": analysis["detected_context"],
        "confidence": analysis["confidence"],
        "all_detections": analysis["all_probs"],
    })
310
+
311
class TTSRequest(BaseModel):
    """Request body for ``POST /tts``."""

    # Text to convert to speech.
    text: str
    # Speed multiplier forwarded to generate_speech().
    speed: float = 1.0


@app.post("/tts")
async def text_to_speech(request: TTSRequest):
    """
    Standalone TTS endpoint for converting text to speech.
    Useful if you want to decouple TTS from the main learning flow.

    Returns a JSON object whose ``audio_url`` is a base64 ``data:`` URL
    holding the generated WAV. The body must be declared with a concrete
    model: a bare ``BaseModel`` parameter has no ``text``/``speed``
    attributes, so every request would fail with a 500.
    """
    try:
        audio_file = generate_speech(request.text, speed=request.speed)

        with open(audio_file, "rb") as f:
            audio_base64 = base64.b64encode(f.read()).decode()
    except HTTPException:
        # Preserve the status and detail chosen by generate_speech().
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return JSONResponse({
        "audio_url": f"data:audio/wav;base64,{audio_base64}"
    })
333
+
334
@app.get("/health")
async def health_check():
    """Report that the service is up, plus the device and model names in use."""
    model_info = {
        "clip_model": CLIP_MODEL_NAME,
        "llm_model": LLM_MODEL_NAME,
        "tts_model": TTS_MODEL_NAME,
    }
    return {"status": "healthy", "device": DEVICE, **model_info}
344
+
345
@app.get("/")
async def root():
    """Describe the service: name, version, available endpoints and models."""
    endpoint_docs = {
        "POST /learn": "Main endpoint - send screenshot + query, get guidance + speech",
        "POST /analyze-screenshot": "Analyze what's on screen",
        "POST /tts": "Standalone text-to-speech conversion",
        "GET /health": "Health check with model info",
    }
    model_names = {
        "image_encoder": CLIP_MODEL_NAME,
        "llm": LLM_MODEL_NAME,
        "tts": TTS_MODEL_NAME,
    }
    return {
        "name": "Coder Tutor Backend",
        "version": "1.0",
        "endpoints": endpoint_docs,
        "models": model_names,
    }
363
+
364
+ # ============================================
365
+ # RUN SERVER
366
+ # ============================================
367
if __name__ == "__main__":
    # Check if running on Hugging Face Spaces
    space_id = os.getenv("SPACE_ID")
    if space_id:
        print(f"[INFO] Running on Hugging Face Space: {space_id}")
        # HF Spaces exposes port 7860 by default
        uvicorn.run(app, host="0.0.0.0", port=7860)
    else:
        # Local development. Auto-reload is deliberately not requested here:
        # uvicorn only supports reload when the app is given as an import
        # string and exits immediately when handed an app object. Use the
        # uvicorn CLI with --reload if hot reloading is wanted.
        uvicorn.run(app, host="127.0.0.1", port=8000)
377
+