Spaces:
Sleeping
Sleeping
| import subprocess | |
| import re | |
| import html | |
| import os | |
| import logging | |
| import time | |
| import torch | |
| from typing import Optional, List, Dict | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
| from transformers.utils import logging as tf_logging | |
| from huggingface_hub import InferenceClient | |
# Lower the transformers library's logging verbosity to ERROR so that its
# informational/warning output does not clutter this app's logs.
tf_logging.set_verbosity_error()
class HFGenerator:
    """Generate text with a Hugging Face causal LM, either remotely or in-process.

    With ``use_api=True`` prompts are sent through ``huggingface_hub.InferenceClient``.
    Otherwise the tokenizer and model are loaded locally: on CUDA the model is
    quantized to 4-bit NF4 via bitsandbytes, on CPU it is loaded unquantized.
    """

    def __init__(self, model_name: str = "meta-llama/Llama-3.2-3B-Instruct", use_api: bool = False):
        """Configure the generator and, in local mode, load tokenizer and model.

        Args:
            model_name: Hugging Face Hub model id.
            use_api: If True, use the hosted Inference API instead of loading
                the model weights locally.
        """
        self.use_api = use_api
        self.model_name = model_name
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        # 1. Token retrieval: HF_TOKEN wins, HUGGINGFACE_HUB_TOKEN is the fallback.
        self.hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")

        if self.use_api:
            print(f"[LLM] Mode: API Inference ({model_name})")
            self.client = InferenceClient(model=model_name, token=self.hf_token)
            return

        print(f"[LLM] Mode: Local Load | Device: {self.device}")

        # 2. 4-bit quantization is only configured when a GPU is available.
        quantization_config = None
        if self.use_cuda:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )

        # 3. Tokenizer; fall back to EOS as the pad token when none is defined.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=self.hf_token)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # 4. Model loading.
        model_kwargs = {
            "token": self.hf_token,
            # NOTE(review): trust_remote_code=True executes Python code shipped
            # in the model repository; only use it with repositories you trust.
            "trust_remote_code": True,
            "device_map": "auto" if self.use_cuda else "cpu",
        }
        if quantization_config:
            model_kwargs["quantization_config"] = quantization_config

        print(f"[LLM] Loading {model_name} {'in 4-bit ' if quantization_config else ''}on {self.device}...")
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
        # Context-window size; 2048 is used when the config lacks the attribute.
        self.max_pos = getattr(self.model.config, "max_position_embeddings", 2048)

    def generate(self, prompt: str, deterministic: bool = False, max_new_tokens: int = 512, temperature: float = 0.7) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Raw prompt text (no chat template is applied here).
            deterministic: If True, use greedy decoding instead of sampling,
                in both API and local mode.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Sampling temperature; not sent when ``deterministic``.

        Returns:
            In API mode, whatever ``InferenceClient.text_generation`` returns.
            In local mode, only the newly generated text (prompt tokens are
            removed before decoding), stripped of surrounding whitespace.
        """
        if self.use_api:
            # Greedy decoding when deterministic; otherwise sample at `temperature`.
            sampling = {"do_sample": False} if deterministic else {"temperature": temperature}
            return self.client.text_generation(
                prompt,
                max_new_tokens=max_new_tokens,
                stop_sequences=["</s>"],
                **sampling,
            )

        # Local inference path.
        # Reserve room for the new tokens inside the context window, and clamp
        # so a large max_new_tokens can never yield a zero/negative max_length.
        # Truncation removes tokens from the tokenizer's truncation side
        # (usually the right, i.e. the end of the prompt).
        prompt_budget = max(1, self.max_pos - max_new_tokens)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=prompt_budget,
        ).to(self.device)
        prompt_len = inputs["input_ids"].shape[-1]

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": not deterministic,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if not deterministic:
            # temperature only has an effect when sampling; passing it together
            # with do_sample=False merely triggers a transformers warning.
            gen_kwargs["temperature"] = temperature

        # float16 autocast on CUDA; bfloat16 autocast on CPU (the author
        # reports a speedup on compatible processors).
        autocast_dtype = torch.float16 if self.use_cuda else torch.bfloat16
        device_type = "cuda" if self.use_cuda else "cpu"
        with torch.autocast(device_type=device_type, dtype=autocast_dtype):
            outputs = self.model.generate(**inputs, **gen_kwargs)

        # Decode only the tokens produced after the prompt. Slicing the decoded
        # text by len(prompt) is wrong whenever decoding does not reproduce the
        # prompt character-for-character (normalisation, truncation, ...).
        new_tokens = outputs[0][prompt_len:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def generate_with_ollama(model_name: str, prompt: str, timeout: Optional[float] = None) -> str:
    """Run ``prompt`` through a local model using the ``ollama run`` CLI.

    Args:
        model_name: Name of the Ollama model passed to ``ollama run``.
        prompt: Prompt text, written to the process's stdin.
        timeout: Optional wall-clock limit in seconds; ``None`` waits forever.

    Returns:
        The CLI's stdout, stripped, on success. Failures are not raised:
        an ``"Ollama error: ..."`` string is returned (and logged) instead.
        When the failure carries captured stderr (e.g. a non-zero exit), that
        text is appended, since the exception text alone only reports the
        exit status.
    """
    try:
        proc = subprocess.run(
            ["ollama", "run", model_name],
            input=prompt,
            capture_output=True,
            text=True,
            # Decode explicitly as UTF-8 instead of the locale default;
            # "replace" keeps a single undecodable byte from failing the call.
            encoding="utf-8",
            errors="replace",
            check=True,
            timeout=timeout,
        )
        return proc.stdout.strip()
    except Exception as e:
        # Deliberate best-effort boundary: callers display the returned string.
        message = f"Ollama error: {e}"
        # CalledProcessError/TimeoutExpired expose captured stderr; it can be
        # bytes on TimeoutExpired, so only append it when it is text.
        stderr = getattr(e, "stderr", None)
        if isinstance(stderr, str) and stderr.strip():
            message = f"{message} | {stderr.strip()}"
        logging.error(message)
        return message
# Lazily created, module-wide InferenceClient shared by all get_client() callers.
_CLIENT_INSTANCE: Optional[InferenceClient] = None


def get_client() -> InferenceClient:
    """Return the shared ``InferenceClient``, building it on the first call.

    The token is read from ``HF_TOKEN`` (falling back to
    ``HUGGINGFACE_HUB_TOKEN``) only when the client is first created; every
    later call returns that same cached instance.
    """
    global _CLIENT_INSTANCE
    if _CLIENT_INSTANCE is not None:
        return _CLIENT_INSTANCE

    # The token is passed explicitly (possibly None) rather than relying solely
    # on credentials the client might discover on its own; the author notes
    # this is the safer choice on Spaces.
    hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
    _CLIENT_INSTANCE = InferenceClient(token=hf_token)
    return _CLIENT_INSTANCE
def generate_with_hf_api(
    model_name: str,
    prompt: str,
    system_message: str = "You are a thoughtful assistant.",
    max_tokens: int = 1024,
    temperature: float = 0.7,
    provider: Optional[str] = None
) -> str:
    """Send a single-turn chat completion to the Hugging Face Inference API.

    Request details, latency and response size are logged for RAG tracking.

    Args:
        model_name: Hub model id.
        prompt: Content of the user message.
        system_message: Content of the system message sent before the prompt.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        provider: Optional inference provider; when given the request targets
            ``"<model_name>:<provider>"``.

    Returns:
        The stripped assistant message (``""`` if the API returned no text).
        Errors are not raised: an ``"HF API error: ..."`` string is returned,
        with a dedicated message when the error text mentions HTTP 429.
    """
    client = get_client()
    full_model_id = f"{model_name}:{provider}" if provider else model_name

    # Log the request details
    logging.info("π HF API Request | Model: %s | Max Tokens: %s", full_model_id, max_tokens)

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]

    # perf_counter is monotonic, so latency is unaffected by wall-clock changes.
    start_time = time.perf_counter()
    try:
        response = client.chat.completions.create(
            model=full_model_id,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            stream=False,
        )
        # `message.content` can be None (e.g. an empty completion); treat it as
        # empty text instead of failing with an AttributeError on .strip().
        content = (response.choices[0].message.content or "").strip()
    except Exception as e:
        duration = time.perf_counter() - start_time
        logging.error("β HF API Error after %.2fs: %s", duration, e)
        # Rate limits get a friendlier message; detection is a substring check.
        if "429" in str(e):
            return "HF API error: Rate limit exceeded. Please wait a moment."
        return f"HF API error: {str(e)}"

    duration = time.perf_counter() - start_time
    # Log successful response metrics
    logging.info("β HF API Success | Latency: %.2fs | Response Len: %d chars", duration, len(content))
    return content
| def generate_answer( | |
| transcription_text: str, | |
| backend: str = "hf", | |
| model_name: str = "meta-llama/Llama-3.2-3B-Instruct", | |
| use_hf_api: bool = True, | |
| max_new_tokens: int = 1024, | |
| ) -> str: | |
| """ | |
| RAG Answer Generation with multi-backend support and performance logging. | |
| """ | |
| start_time = time.time() | |
| # 2. Build the prompt | |
| meeting_prompt = f""" | |
| ### ROLE: COGNITIVE EXPERT & PROJECT MANAGER | |
| Act as an expert Cognitive Editor and Business Analyst. Your goal is to transform messy, error-prone speech-to-text into a polished, high-value document. | |
| ### TASK 1: CLEANING | |
| - Remove ASR hallucinations (e.g., repeating words like "you you you" or "thank you" during silence). | |
| - Correct homophone errors (e.g., "there" vs "their", "cash" vs "cache"). | |
| - Remove filler words (ums, ahs, "you know") while preserving the speaker's original meaning. | |
| ### TASK 2: CATEGORIZATION & FORMATTING | |
| Analyze the cleaned text. Determine the "Content Type" and format the output accordingly: | |
| 1. **If it's a MEETING:** | |
| - Provide "Executive Summary," "Key Decisions," and "Action Items" (Task | Owner | Deadline). | |
| 2. **If it's a BRAINSTORM/IDEA JUGGLING:** | |
| - Organize disorganized thoughts into a "Logical Framework." | |
| - Identify the "Core Concept" and provide "Strategic Guidelines" for the next steps. | |
| 3. **If it's a LESSON/LECTURE:** | |
| - Create a "Conceptual Map." | |
| - Summarize into "Key Takeaways" and "Definitions" of complex terms used. | |
| ### TRANSCRIPTION DATA: | |
| \"\"\" | |
| {transcription_text} | |
| \"\"\" | |
| ### OUTPUT: | |
| (Start with a 1-sentence "Intent Identification" e.g., "This appears to be a brainstorming session regarding...") | |
| """ | |
| logging.info(f"Built prompt with transcription: {transcription_text[:500]}...") # Log first 500 chars | |
| try: | |
| # Route to HF API | |
| if use_hf_api: | |
| logging.info(f"π Routing to Hugging Face API (Model: {model_name})") | |
| response = generate_with_hf_api(model_name, meeting_prompt, max_tokens=max_new_tokens) | |
| # Route to Local Backends | |
| elif backend == "ollama": | |
| logging.info(f"π¦ Routing to Ollama (Model: {model_name})") | |
| response = generate_with_ollama(model_name, meeting_prompt) | |
| elif backend == "hf": | |
| logging.info(f"ποΈ Routing to Local HF Transformers (Model: {model_name})") | |
| generator = HFGenerator(model_name=model_name) | |
| response = generator.generate( | |
| meeting_prompt, | |
| max_new_tokens=max_new_tokens, | |
| deterministic=True, | |
| ) | |
| else: | |
| response = f"Unsupported backend: {backend}" | |
| logging.warning(f"β οΈ {response}") | |
| except Exception as e: | |
| logging.error(f"β Generation Failed: {str(e)}") | |
| response = f"Error during generation: {str(e)}" | |
| # 3. Log Performance Metrics | |
| duration = time.time() - start_time | |
| logging.info(f"β±οΈ Generation Complete | Time: {duration:.2f}s | Response Length: {len(response)} chars") | |
| return response | |