|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
from peft import PeftModel |
|
|
import torch |
|
|
import os |
|
|
import re |
|
|
import json |
|
|
from flask import Flask, request, jsonify |
|
|
|
|
|
# Flask application object; the route handlers in this file register on it.
app = Flask(__name__)

# The LoRA adapter directory is resolved relative to this file rather than
# the process's current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
adapter_path = os.path.join(script_dir, "lora-playwright-adapter")

# Base checkpoint the adapter is applied on top of.
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Reuse the end-of-sequence token for padding when the tokenizer does not
# define a pad token of its own.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})

# Load the base weights in float16 with automatic device placement, then
# wrap them with the adapter stored at adapter_path.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, adapter_path)

# The model is only used for generation here, so put it in evaluation mode.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
# Greedy match from the first "[" to the last "]" of the generated text.
_JSON_ARRAY_RE = re.compile(r"\[.*\]", re.DOTALL)


def _parse_json_array(text):
    """Decode *text* as a JSON array, else the first valid array inside it.

    Returns the decoded list, or None when nothing in *text* parses.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # The greedy span can swallow trailing junk or a second array; recover by
    # decoding from each "[" in turn and keeping the first one that parses.
    decoder = json.JSONDecoder()
    start = text.find("[")
    while start != -1:
        try:
            parsed, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("[", start + 1)
        else:
            return parsed
    return None


def generate_action_sequence(test_goals):
    """Generate one JSON action array per goal and concatenate the results.

    Args:
        test_goals: Iterable of goal descriptions; each one is interpolated
            into the prompt. ``None`` is treated as "no goals".

    Returns:
        A flat list holding the elements of every array that could be parsed
        from the model output, in goal order. Goals whose output contains no
        parseable array are reported on stdout and contribute nothing.
    """
    full_response = []
    if test_goals is None:
        return full_response

    for goal in test_goals:
        prompt = f"Goal: {goal}\nReturn only one valid JSON array, no explanation.\nOutput:"
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=150,
                pad_token_id=tokenizer.pad_token_id,
                top_p=1.0,
                repetition_penalty=1.2,
                do_sample=False,
            )

        # The generated sequence starts with the prompt tokens. Decode only
        # what follows them, so brackets inside the goal text cannot be
        # mistaken for the start of the JSON answer.
        prompt_length = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        match = _JSON_ARRAY_RE.search(response)
        if match is None:
            print(f"No JSON found for goal: {goal}")
            continue

        actions = _parse_json_array(match.group(0))
        if actions is None:
            print(f"Invalid JSON for goal: {goal}")
            continue

        full_response.extend(actions)

    return full_response
|
|
|
|
|
@app.route("/")
def health():
    """Health-check endpoint: answer with the body "OK" and HTTP status 200."""
    body, status = "OK", 200
    return body, status
|
|
|
|
|
@app.route("/generate", methods=["POST"])
def generate():
    """Run action-sequence generation for the goals in the POSTed JSON body.

    Expects a JSON object with an optional ``"goals"`` key holding a list;
    a missing key is treated as an empty list.

    Returns:
        ``{"result": [...]}`` on success, or ``{"error": "..."}`` with HTTP
        status 400 when the body is not a JSON object or ``"goals"`` is not
        a list.
    """
    # silent=True yields None instead of raising on a missing or malformed
    # JSON body, so bad input is reported here as a 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    test_goals = data.get("goals", [])
    # A bare string would otherwise be iterated character by character.
    if not isinstance(test_goals, list):
        return jsonify({"error": "'goals' must be a list."}), 400

    result = generate_action_sequence(test_goals)
    return jsonify({"result": result})
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve on all interfaces. The port is read from the PORT environment
    # variable and falls back to 5000 when that variable is unset.
    port = int(os.getenv("PORT", 5000))
    app.run(host="0.0.0.0", port=port)
|
|
|
|
|
|