Spaces:
Sleeping
Sleeping
# app/models/llm.py
import re

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from app.utils.config import DEVICE, QWEN_MODEL_ID

# Load the chat model once at import time; it is shared by all requests.
tokenizer = AutoTokenizer.from_pretrained(
    QWEN_MODEL_ID,
    trust_remote_code=True,
)

model = AutoModelForCausalLM.from_pretrained(
    QWEN_MODEL_ID,
    # Let accelerate place layers across GPUs when CUDA is available.
    device_map="auto" if DEVICE == "cuda" else None,
    # Half precision on GPU to halve memory; full precision on CPU.
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    trust_remote_code=True,
)
model.eval()  # inference only — disables dropout / training-mode layers
def strip_markdown(text: str) -> str:
    """Strip common Markdown formatting from LLM output, returning plain text.

    Handles bold/italic emphasis, ATX headers, bullet lists, inline code /
    fences, and horizontal rules.
    """
    text = re.sub(r"\*\*(.*?)\*\*", r"\1", text)  # **bold** -> bold
    text = re.sub(r"\*(.*?)\*", r"\1", text)      # *italic* -> italic
    # Anchor at line start so mid-sentence '#' and '- ' (e.g. "3 - 5 cm")
    # are left untouched; the unanchored patterns mangled running text.
    text = re.sub(r"^#+\s*", "", text, flags=re.MULTILINE)  # headers
    text = re.sub(r"^-\s+", "", text, flags=re.MULTILINE)   # bullets
    text = re.sub(r"`+", "", text)                # inline code / fences
    text = text.replace("---", "")                # horizontal rules
    return text.strip()
def infer_common_name(
    species: str,
    domain: str,
    max_tokens: int = 16,
) -> str | None:
    """
    Use the LLM to infer the most widely accepted English common name.

    Args:
        species: Scientific (Latin) name, e.g. "Panthera leo".
        domain: Taxonomic domain/kingdom, included as context for the model.
        max_tokens: Generation budget; common names are short.

    Returns:
        The common name, or None when the model reports that none exists.
    """
    messages = [
        {
            "role": "system",
            "content": (
                # The original prompt demanded a name even when none exists
                # ("dont respond null ..."), contradicting the documented
                # None return; it also lacked a space between sentences.
                "You are a taxonomy assistant. "
                "Return ONLY the most widely used English common name "
                "for the given scientific name. "
                "Do not explain or add extra text. "
                "If no widely used English common name exists, reply with "
                "exactly: None"
            ),
        },
        {
            "role": "user",
            "content": f"Scientific name: {species} ({domain})",
        },
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # No gradients needed at inference; saves memory on long prompts.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            # Greedy decoding for determinism. `temperature` is omitted:
            # it is ignored (and warned about) when do_sample=False.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Keep only the newly generated tokens (drop the echoed prompt).
    generated_ids = outputs[:, inputs.input_ids.shape[1]:]
    response = tokenizer.decode(
        generated_ids[0],
        skip_special_tokens=True,
    ).strip()
    # Accept "None." / "none" variants as "no common name".
    if not response or response.rstrip(".").lower() == "none":
        return None
    return response
| def _build_messages( | |
| species: str, | |
| confidence: float, | |
| domain: str, | |
| top_k: list | None = None, | |
| ): | |
| alternatives = "" | |
| if top_k: | |
| alternatives = "\n".join( | |
| [f"{x['species']} ({x['similarity']:.2f})" for x in top_k[1:]] | |
| ) | |
| system_message = ( | |
| "You are a scientific biodiversity assistant. " | |
| "Provide factual, neutral descriptions of species. " | |
| "Do not mention instructions, rules, or formatting. " | |
| "Do not use markdown or bullet points." | |
| ) | |
| user_message = ( | |
| f"Species: {species}\n" | |
| f"Confidence: {confidence:.2f}\n\n" | |
| f"Alternative candidates:\n" | |
| f"{alternatives if alternatives else 'None'}\n\n" | |
| "Provide a factual description covering physical traits, " | |
| "natural habitat and distribution, diet or ecological role, " | |
| "conservation status, and relevant human interactions. " | |
| ) | |
| return [ | |
| {"role": "system", "content": system_message}, | |
| {"role": "user", "content": user_message}, | |
| ] | |
def explain_species(
    species: str,
    confidence: float,
    domain: str,
    top_k: list | None = None,
    max_tokens: int = 512,
):
    """
    Generate a common name and a plain-text description for a species.

    Args:
        species: Predicted scientific name.
        confidence: Classifier confidence in [0, 1].
        domain: Taxonomic domain of the prediction.
        top_k: Optional ranked candidate list (see `_build_messages`).
        max_tokens: Generation budget for the description.

    Returns:
        {
            "common_name": str | None,
            "description": str
        }
    """
    # Skip the extra LLM call for near-zero-confidence predictions.
    # NOTE(review): 0.01 accepts almost everything — confirm this gate
    # is intentionally this permissive.
    COMMON_NAME_MIN_CONFIDENCE = 0.01
    common_name = None
    if confidence >= COMMON_NAME_MIN_CONFIDENCE:
        common_name = infer_common_name(species, domain)
    messages = _build_messages(species, confidence, domain, top_k)
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # No gradients needed at inference; saves memory on long generations.
    with torch.no_grad():
        outputs = model.generate(
            **model_inputs,
            max_new_tokens=max_tokens,
            # Greedy decoding for determinism. `temperature` is omitted:
            # it is ignored (and warned about) when do_sample=False.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Keep only the newly generated tokens (drop the echoed prompt).
    generated_ids = outputs[:, model_inputs.input_ids.shape[1]:]
    response = tokenizer.decode(
        generated_ids[0],
        skip_special_tokens=True,
    )
    return {
        "common_name": common_name,
        "description": strip_markdown(response),
    }