from fastapi import FastAPI, UploadFile, File
import torch
import os
from pathlib import Path
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
from pydantic import BaseModel
import tempfile
import hashlib
import json
from typing import Optional
import httpx  # HTTP client used for the Groq transcription API
from dotenv import load_dotenv

load_dotenv()
# Define input model
class TextInput(BaseModel):
    text: str


# Initialize FastAPI
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    # You can restrict this to your specific frontend
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Get base directory
base_dir = Path(__file__).parent.absolute()

# Your Hugging Face Hub username
HF_USERNAME = "YassineJedidi"  # Replace with your actual username

# Valid entity labels for each text type
entites_valides = {
    "Tâche": {"TITRE", "DELAI", "PRIORITE"},
    "Événement": {"TITRE", "DATE_HEURE"},
}
# Try to load models from Hugging Face Hub
try:
    print("Loading models from Hugging Face Hub")
    # Model repositories on Hugging Face
    ner_model_repo = f"{HF_USERNAME}/plangenieai-ner"
    type_model_repo = f"{HF_USERNAME}/plangenieai-type"
    print(f"Loading NER model (and tokenizer) from: {ner_model_repo}")
    print(f"Loading type model (and tokenizer) from: {type_model_repo}")

    # Load NER model and tokenizer from the same repo
    ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_repo)
    ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_repo)

    # Load type model and tokenizer from the same repo
    type_tokenizer = AutoTokenizer.from_pretrained(type_model_repo)
    type_model = AutoModelForSequenceClassification.from_pretrained(type_model_repo)
except Exception as e:
    print(f"Error loading models from Hugging Face Hub: {e}")
    # Fall back to local files if available
    try:
        # Convert paths to strings with forward slashes
        ner_model_path = str(base_dir / "models" / "plangenieai-ner").replace("\\", "/")
        type_model_path = str(base_dir / "models" / "plangenieai-type").replace("\\", "/")
        print("Falling back to local models")
        print(f"Loading NER model (and tokenizer) from: {ner_model_path}")
        print(f"Loading type model (and tokenizer) from: {type_model_path}")

        # Load NER model and tokenizer from local files
        ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_path, local_files_only=True)
        ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_path, local_files_only=True)

        # Load type model and tokenizer from local files
        type_tokenizer = AutoTokenizer.from_pretrained(type_model_path, local_files_only=True)
        type_model = AutoModelForSequenceClassification.from_pretrained(type_model_path, local_files_only=True)
    except Exception as e:
        print(f"Error loading local models: {e}")
        # Fall back to the base CamemBERT model from the Hugging Face Hub
        print("Falling back to base CamemBERT model from HuggingFace Hub")
        ner_tokenizer = AutoTokenizer.from_pretrained("camembert-base")
        ner_model = AutoModelForTokenClassification.from_pretrained("camembert-base")
        type_tokenizer = AutoTokenizer.from_pretrained("camembert-base")
        type_model = AutoModelForSequenceClassification.from_pretrained("camembert-base")
# Helper functions for tokenization
def clean_text(text):
    if isinstance(text, str):
        return text.strip()
    return ""


def find_all_occurrences(text, substring):
    start_positions = []
    start = 0
    if not substring or not isinstance(substring, str):
        return start_positions
    text_lower = text.lower()
    substring_lower = substring.lower()
    while True:
        start = text_lower.find(substring_lower, start)
        if start == -1:
            break
        # Only accept case-insensitive matches that sit on word boundaries
        is_beginning = start == 0 or not text_lower[start - 1].isalnum()
        is_ending = (start + len(substring_lower) == len(text_lower) or
                     not text_lower[start + len(substring_lower)].isalnum())
        if is_beginning and is_ending:
            original_substring = text[start:start + len(substring_lower)]
            start_positions.append((start, start + len(substring_lower), original_substring))
        start += 1
    return start_positions
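
# Example: the boundary checks mean "pain" is found in "Acheter du pain"
# but not inside "copain":
#   find_all_occurrences("Acheter du pain", "pain")  -> [(11, 15, "pain")]
#   find_all_occurrences("Mon copain", "pain")       -> []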

def tokenize_text_with_positions(text, tokenizer):
    """Tokenize text and return tokens with their positions."""
    # Use CamemBERT tokenizer
    tokens = tokenizer.tokenize(text)

    # Clean tokens and get positions
    clean_tokens = []
    token_positions = []
    current_pos = 0
    for token in tokens:
        # Clean the token (remove special characters from tokenizer)
        clean_token = token.replace('▁', '').replace('##', '')
        clean_tokens.append(clean_token)
        if clean_token:
            pos = text.find(clean_token, current_pos)
            if pos != -1:
                token_positions.append((pos, pos + len(clean_token)))
                current_pos = pos + len(clean_token)
            else:
                token_positions.append((current_pos, current_pos + len(clean_token)))
                current_pos += len(clean_token)
        else:
            token_positions.append((current_pos, current_pos))
    return clean_tokens, token_positions
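
# Example (illustrative; the exact sub-tokens depend on the CamemBERT vocabulary):
#   tokenize_text_with_positions("Réunion demain", ner_tokenizer)
#   might return (["Réunion", "demain"], [(0, 7), (8, 14)])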

# Set device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Groq API key and endpoint used for audio transcription
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GROQ_API_URL = "https://api.groq.com/openai/v1/audio/transcriptions"

ner_model = ner_model.to(device)
type_model = type_model.to(device)

# Retrieve label mappings
id2label = ner_model.config.id2label
id2type = type_model.config.id2label

# Cache directory for transcriptions
CACHE_DIR = Path("transcription_cache")
CACHE_DIR.mkdir(exist_ok=True)

def get_cache_path(audio_data: bytes) -> Path:
    """Generate a cache file path based on the audio content hash."""
    hash_md5 = hashlib.md5(audio_data).hexdigest()
    return CACHE_DIR / f"{hash_md5}.json"


def get_cached_transcription(audio_data: bytes) -> Optional[str]:
    """Get cached transcription if it exists."""
    cache_path = get_cache_path(audio_data)
    if cache_path.exists():
        try:
            with open(cache_path, 'r') as f:
                return json.load(f)['transcription']
        except Exception:
            return None
    return None


def save_transcription_to_cache(audio_data: bytes, transcription: str):
    """Save transcription to cache."""
    cache_path = get_cache_path(audio_data)
    try:
        with open(cache_path, 'w') as f:
            json.dump({'transcription': transcription}, f)
    except Exception:
        pass  # Silently fail if cache write fails
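
# Example: the cache is keyed on the raw audio bytes (MD5), so uploading the
# same file twice returns the cached text without calling the API:
#   save_transcription_to_cache(b"<audio bytes>", "Bonjour")
#   get_cached_transcription(b"<audio bytes>")  -> "Bonjour"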

# NOTE: the route decorators below were missing from the original listing;
# the paths used here are assumptions and may differ from the deployed API.
@app.get("/")
def root():
    return {"message": "FastAPI NLP Model is running!"}

@app.post("/predict-type")
async def predict_type(input_data: TextInput):
    text = input_data.text
    inputs = type_tokenizer(text, return_tensors="pt",
                            truncation=True, padding=True).to(device)
    with torch.no_grad():
        outputs = type_model(**inputs)
    predicted_class_id = outputs.logits.argmax().item()
    predicted_type = id2type[predicted_class_id]
    confidence = torch.softmax(outputs.logits, dim=1).max().item()
    return {"type": predicted_type, "confidence": confidence}

@app.post("/extract-entities")
async def extract_entities(input_data: TextInput):
    text = input_data.text

    # Use the model's tokenizer for tokenization
    clean_tokens, token_positions = tokenize_text_with_positions(text, ner_tokenizer)

    # Tokenize for NER prediction
    inputs = ner_tokenizer(clean_tokens, is_split_into_words=True,
                           return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        outputs = ner_model(**inputs)
    predictions = outputs.logits.argmax(dim=2)

    entities = {}
    current_entity = None
    current_start = None
    current_end = None
    word_ids = inputs.word_ids(0)

    for idx, word_idx in enumerate(word_ids):
        if word_idx is None:
            continue
        # Only score the first sub-token of each word
        if idx > 0 and word_ids[idx - 1] == word_idx:
            continue
        prediction = predictions[0, idx].item()
        predicted_label = id2label[prediction]

        if predicted_label.startswith("B-"):
            # Close the previous entity; only the first detection of a type is kept
            if current_entity is not None:
                entity_type = current_entity[2:]
                if entity_type not in entities:
                    entities[entity_type] = [text[current_start:current_end]]
            current_entity = predicted_label
            current_start, current_end = token_positions[word_idx]
        elif predicted_label.startswith("I-") and current_entity and predicted_label[2:] == current_entity[2:]:
            # Extend the end position to include this token
            _, token_end = token_positions[word_idx]
            current_end = token_end
        else:
            # Close the previous entity; only the first detection of a type is kept
            if current_entity is not None:
                entity_type = current_entity[2:]
                if entity_type not in entities:
                    entities[entity_type] = [text[current_start:current_end]]
            current_entity = None
            current_start = None
            current_end = None

    # Flush any entity still open at the end of the text
    if current_entity is not None:
        entity_type = current_entity[2:]
        if entity_type not in entities:
            entities[entity_type] = [text[current_start:current_end]]

    return {"entities": entities}

@app.post("/analyze")
async def analyze_text(input_data: TextInput):
    type_result = await predict_type(input_data)
    text_type = type_result["type"]
    confidence = type_result["confidence"]
    raw_entities = (await extract_entities(input_data))["entities"]

    # Filter entities according to the detected type
    allowed = entites_valides.get(text_type, set())
    filtered_entities = {k: v for k, v in raw_entities.items() if k in allowed}

    return {
        "type": text_type,
        "confidence": confidence,
        "entities": filtered_entities
    }
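
# Example end-to-end call (illustrative; the "/analyze" path above is an assumption):
#   import httpx
#   httpx.post("http://localhost:8000/analyze",
#              json={"text": "Réunion avec l'équipe demain à 14h"}).json()
#   -> e.g. {"type": "Événement", "confidence": 0.95,
#            "entities": {"TITRE": ["Réunion avec l'équipe"],
#                         "DATE_HEURE": ["demain à 14h"]}}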

@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    try:
        # Read the file content
        audio_data = await file.read()

        # Check cache first
        cached_transcription = get_cached_transcription(audio_data)
        if cached_transcription:
            return {"transcription": cached_transcription, "cached": True}

        # Save audio to a temporary file (Groq expects multipart/form-data)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp.write(audio_data)
            tmp_path = tmp.name

        # Prepare request to Groq API
        headers = {"Authorization": f"Bearer {GROQ_API_KEY}"}
        data = {
            "model": "whisper-large-v3-turbo",
            "response_format": "json"
        }
        try:
            # Keep the file handle in a context manager so it is closed
            # before the temp file is removed (the original leaked it)
            with open(tmp_path, "rb") as audio_file:
                files = {
                    "file": (os.path.basename(tmp_path), audio_file, "audio/wav")
                }
                async with httpx.AsyncClient() as client:
                    response = await client.post(GROQ_API_URL, headers=headers,
                                                 data=data, files=files, timeout=60)
        finally:
            # Clean up temp file, even if the request fails
            os.remove(tmp_path)

        if response.status_code == 200:
            result = response.json()
            transcription = result.get("text", "")
            # Save to cache
            save_transcription_to_cache(audio_data, transcription)
            return {"transcription": transcription, "cached": False}
        else:
            print(f"Groq API error: {response.status_code} {response.text}")
            return {"error": "Transcription failed", "details": response.text}
    except Exception as e:
        print(f"Transcription error: {str(e)}")
        return {"error": "Transcription failed", "details": str(e)}