# sibimani's picture
# Upload 2 files
# 376da64 verified
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import os
import re
import json
from flask import Flask, request, jsonify
# Flask application object; the route decorators further down attach to it.
app = Flask(__name__)

# The LoRA adapter directory is looked up next to this script, so the
# service does not depend on the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
adapter_path = os.path.join(script_dir, "lora-playwright-adapter")

# Base checkpoint that the adapter is applied on top of.
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenizing with padding=True needs a pad token; reuse EOS when the
# tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

# Load the base weights in half precision and let device_map="auto" decide
# where they are placed.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Wrap the base model with the adapter weights and switch to evaluation mode.
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()
# Test goals are supplied by the caller (see the /generate route); this could be extended to load them from Excel/CSV.
def _extract_json_array(text):
    """Decode the JSON array that begins at the first ``[`` in *text*.

    Decoding stops at the bracket that closes that array, so anything the
    model emits afterwards (an explanation, a second array, a stray ``]``)
    is ignored instead of invalidating the result.

    Only the first ``[`` is tried on purpose: scanning later brackets could
    return an inner fragment of an outer array that was cut off mid-way.

    Args:
        text: Decoded model output.

    Returns:
        The parsed list, or ``None`` when *text* contains no ``[`` at all.

    Raises:
        json.JSONDecodeError: If the text starting at the first ``[`` is not
            a well-formed JSON array.
    """
    start = text.find("[")
    if start == -1:
        return None
    parsed, _end = json.JSONDecoder().raw_decode(text, start)
    return parsed


def generate_action_sequence(test_goals):
    """Generate one flat list of actions covering all given test goals.

    Each goal is sent to the model as its own prompt. The JSON array found
    in the completion is appended, element by element, to the result. Goals
    whose completion holds no parseable array are reported on stdout and
    skipped, so one bad completion does not discard the others.

    Args:
        test_goals: Iterable of goals, each interpolated into the prompt. A
            single string is treated as one goal; ``None`` or an empty
            collection yields an empty result.

    Returns:
        A list with the elements of every parsed array, in goal order.
    """
    if not test_goals:
        return []
    if isinstance(test_goals, str):
        # Iterating a bare string would issue one prompt per character.
        test_goals = [test_goals]

    full_response = []
    for goal in test_goals:
        prompt = f"Goal: {goal}\nReturn only one valid JSON array, no explanation.\nOutput:"
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=150,
                pad_token_id=tokenizer.pad_token_id,
                top_p=1.0,
                repetition_penalty=1.2,
                do_sample=False
            )

        # The returned sequence starts with the prompt tokens. Decode only
        # what follows them, so brackets inside the goal text can never be
        # mistaken for model output.
        prompt_length = inputs["input_ids"].shape[1]
        completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        try:
            actions = _extract_json_array(completion)
        except json.JSONDecodeError:
            print(f"Invalid JSON for goal: {goal}")
            continue
        if actions is None:
            print(f"No JSON found for goal: {goal}")
            continue
        full_response.extend(actions)
    return full_response
@app.route("/")
def health():
    """Health-check endpoint: reply with the body ``OK`` and HTTP status 200."""
    body, status = "OK", 200
    return body, status
@app.route("/generate", methods=["POST"])
def generate():
    """Handle ``POST /generate``: turn the posted goals into an action list.

    The request body must be a JSON object. Its optional ``"goals"`` entry
    is a list of goals; a single string is accepted as one goal, and a
    missing entry means no goals.

    Returns:
        ``{"result": [...]}`` on success, or ``{"error": "..."}`` with HTTP
        status 400 when the body is not a JSON object or ``"goals"`` has an
        unsupported type.
    """
    data = request.get_json()
    if not isinstance(data, dict):
        # The parsed body may be None or a non-object value (list, string,
        # number); calling .get on it would surface as an HTTP 500.
        return jsonify({"error": "Request body must be a JSON object."}), 400

    test_goals = data.get("goals", [])
    if isinstance(test_goals, str):
        # Treat a bare string as one goal rather than a sequence of characters.
        test_goals = [test_goals]
    elif not isinstance(test_goals, list):
        return jsonify({"error": "'goals' must be a list of goal strings."}), 400

    result = generate_action_sequence(test_goals)
    return jsonify({"result": result})
if __name__ == "__main__":
    # Run the built-in Flask server only when executed as a script. The port
    # comes from the PORT environment variable (default 5000), and the server
    # binds to 0.0.0.0.
    port = int(os.environ.get("PORT", 5000))
    app.run(port=port, host="0.0.0.0")