|
|
from tools.code_generator import generate_code |
|
|
from tools.web_search import search_web |
|
|
from tools.rag_engine import answer_from_docs |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import torch |
|
|
|
|
|
|
|
|
# Load the chat-tuned TinyLlama checkpoint used for the fallback (non-routed) path.
# NOTE(review): weights are downloaded/loaded at import time — confirm this module
# is only imported once per process, as this is a heavyweight side effect.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")


model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")


# Inference-only usage: disable dropout/batch-norm training behavior.
model.eval()
|
|
|
|
|
def route_query(prompt: str) -> str:
    """Dispatch *prompt* to a specialist tool, or fall back to the local LLM.

    Routing is keyword-based on a case-insensitive scan of the prompt:
    - contains "code:"   -> tools.code_generator.generate_code
    - contains "search:" -> tools.web_search.search_web
    - contains "doc:"    -> tools.rag_engine.answer_from_docs
    - otherwise          -> generate up to 200 new tokens with the local
      TinyLlama model and return the decoded text.

    The original (non-lowercased) prompt is always passed through to the
    selected handler. Note the keyword match is a substring test anywhere
    in the prompt, not a prefix test, and "code:" wins over the others
    when several keywords are present.

    Args:
        prompt: The raw user query.

    Returns:
        The tool's answer, or the model's decoded output. The decoded
        output includes the original prompt text (the generate/decode
        round-trip echoes the input tokens).
    """
    prompt_lower = prompt.lower()

    if "code:" in prompt_lower:
        return generate_code(prompt)
    elif "search:" in prompt_lower:
        return search_web(prompt)
    elif "doc:" in prompt_lower:
        return answer_from_docs(prompt)
    else:
        inputs = tokenizer(prompt, return_tensors="pt")
        # Fix: generation previously ran with autograd enabled even though
        # model.eval() was set; inference_mode avoids building a gradient
        # graph, cutting memory use during generate().
        with torch.inference_mode():
            outputs = model.generate(**inputs, max_new_tokens=200)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|